diff --git a/common/common.go b/common/common.go index 2e395bba913..ae2e5273f7f 100644 --- a/common/common.go +++ b/common/common.go @@ -696,3 +696,58 @@ func SetIfZero[T comparable](p *T, def T) bool { *p = def return true } + +var ( + contextKeys []any + contextKeysMu sync.RWMutex +) + +// RegisterContextKey registers a key to be captured by FreezeContext +func RegisterContextKey(key any) { + contextKeysMu.Lock() + defer contextKeysMu.Unlock() + if !slices.Contains(contextKeys, key) { + contextKeys = append(contextKeys, key) + } +} + +// FrozenContext holds captured context values +type FrozenContext map[any]any + +// FreezeContext captures values from the context for registered keys +func FreezeContext(ctx context.Context) FrozenContext { + contextKeysMu.RLock() + defer contextKeysMu.RUnlock() + + values := make(FrozenContext, len(contextKeys)) + for _, key := range contextKeys { + if val := ctx.Value(key); val != nil { + values[key] = val + } + } + return values +} + +// ThawContext creates a new context from the frozen context using context.Background() as parent +func ThawContext(fc FrozenContext) context.Context { + return MergeContext(context.Background(), fc) +} + +// MergeContext adds the frozen values to an existing context +func MergeContext(ctx context.Context, fc FrozenContext) context.Context { + return &mergedContext{Context: ctx, frozen: fc} +} + +// mergedContext is a context that has merged values from a frozen context and a parent context. +// frozen values are stored in FrozenContext instead of nested context.WithValue because of the performance of calling WithValue N+ times on messages being frozen +type mergedContext struct { + context.Context //nolint:containedctx // mergedContext implements context.Context + frozen FrozenContext +} + +func (m *mergedContext) Value(key any) any { + if val, ok := m.frozen[key]; ok { + return val + } + return m.Context.Value(key) +} diff --git a/common/common_test.go b/common/common_test.go index 278a4582085..249044a967b 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -1,6 +1,7 @@ package common import ( + "context" "errors" "fmt" "net/http" @@ -691,3 +692,36 @@ func TestSetIfZero(t *testing.T) { assert.True(t, changed, "SetIfZero should change a zero value") assert.Equal(t, "world", s, "SetIfZero should change a zero value") } + +func TestContextFunctions(t *testing.T) { + t.Parallel() + + type key string + const k1 key = "key1" + const k2 key = "key2" + const k3 key = "key3" + + RegisterContextKey(k1) + RegisterContextKey(k2) + + ctx := context.WithValue(context.Background(), k1, "value1") + ctx = context.WithValue(ctx, k2, "value2") + ctx = context.WithValue(ctx, k3, "value3") // Not registered + + frozen := FreezeContext(ctx) + + assert.Equal(t, "value1", frozen[k1], "should have captured k1") + assert.Equal(t, "value2", frozen[k2], "should have captured k2") + assert.Zero(t, frozen[k3], "k3 should not be captured") + + thawed := ThawContext(frozen) + assert.Equal(t, "value1", thawed.Value(k1), "should have k1 after thaw") + assert.Equal(t, "value2", thawed.Value(k2), "should have k2 after thaw") + assert.Nil(t, thawed.Value(k3), "Thawed context should not have k3") + + ctx2 := context.WithValue(context.Background(), k3, "value3_new") + merged := MergeContext(ctx2, frozen) + assert.Equal(t, "value1", merged.Value(k1), "should have k1 from frozen") + assert.Equal(t, "value2", merged.Value(k2), "should have k2 from frozen") + assert.Equal(t, "value3_new", merged.Value(k3), "should have k3 from parent") +} diff --git a/docs/ADD_NEW_EXCHANGE.md b/docs/ADD_NEW_EXCHANGE.md index 38c79f5bad4..140aa5e074e 100644 --- a/docs/ADD_NEW_EXCHANGE.md +++ b/docs/ADD_NEW_EXCHANGE.md @@ -693,19 +693,17 @@ func (e *Exchange) WsConnect() error { // KeepAuthKeyAlive will continuously send messages to // keep the WS auth key active func (e *Exchange) KeepAuthKeyAlive(ctx context.Context) { - e.Websocket.Wg.Add(1) defer e.Websocket.Wg.Done() - ticks := time.NewTicker(time.Minute * 30) for { select { case <-e.Websocket.ShutdownC: - ticks.Stop() return - case <-ticks.C: - err := e.MaintainWsAuthStreamKey(ctx) - if err != nil { - e.Websocket.DataHandler <- err - log.Warnf(log.ExchangeSys, "%s - Unable to renew auth websocket token, may experience shutdown", e.Name) + case <-time.After(time.Minute * 30): + if err := e.MaintainWsAuthStreamKey(ctx); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } + log.Warnf(log.ExchangeSys, "%s %s: Unable to renew auth websocket token, may experience shutdown", e.Name, e.Websocket.Conn.GetURL()) } } } @@ -817,9 +815,7 @@ Run gocryptotrader with the following settings enabled in config ```go // wsReadData gets and passes on websocket messages for processing func (e *Exchange) wsReadData() { - e.Websocket.Wg.Add(1) defer e.Websocket.Wg.Done() - for { select { case <-e.Websocket.ShutdownC: @@ -829,10 +825,10 @@ func (e *Exchange) wsReadData() { if resp.Raw == nil { return } - - err := e.wsHandleData(resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } @@ -875,7 +871,7 @@ If a suitable struct does not exist in wshandler, wrapper types are the next pre if err := json.Unmarshal(respRaw, &resultData);err != nil { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Bid: resultData.Ticker.Bid, Ask: resultData.Ticker.Ask, @@ -883,7 +879,7 @@ If a suitable struct does not exist in wshandler, wrapper types are the next pre LastUpdated: resultData.Ticker.Time, Pair: p, AssetType: a, - } + }) } ``` @@ -896,7 +892,7 @@ If neither of those provide a suitable struct to store the data in, the data can if err != nil { return err } - e.Websocket.DataHandler <- resultData.FillsData + return e.Websocket.DataHandler.Send(ctx, resultData.FillsData) ``` - Data Handling can be tested offline similar to the following example: diff --git a/engine/rpcserver.go b/engine/rpcserver.go index 24cd3d483c4..6f5771785cb 100644 --- a/engine/rpcserver.go +++ b/engine/rpcserver.go @@ -2947,7 +2947,7 @@ func (s *RPCServer) WebsocketGetInfo(_ context.Context, r *gctrpc.WebsocketGetIn } // WebsocketSetEnabled enables or disables the websocket client -func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) { +func (s *RPCServer) WebsocketSetEnabled(ctx context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) { exch, err := s.GetExchangeByName(r.Exchange) if err != nil { return nil, err @@ -2964,11 +2964,9 @@ func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSe } if r.Enable { - err = w.Enable() - if err != nil { + if err := w.Enable(context.WithoutCancel(ctx)); err != nil { return nil, err } - exchCfg.Features.Enabled.Websocket = true return &gctrpc.GenericResponse{Status: MsgStatusSuccess, Data: "websocket enabled"}, nil } @@ -3013,7 +3011,7 @@ func (s *RPCServer) WebsocketGetSubscriptions(_ context.Context, r *gctrpc.Webso } // WebsocketSetProxy sets client websocket connection proxy -func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) { +func (s *RPCServer) WebsocketSetProxy(ctx context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) { exch, err := s.GetExchangeByName(r.Exchange) if err != nil { return nil, err @@ -3024,15 +3022,12 @@ func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetP return nil, fmt.Errorf("websocket not supported for exchange %s", r.Exchange) } - err = w.SetProxyAddress(r.Proxy) - if err != nil { + if err := w.SetProxyAddress(context.WithoutCancel(ctx), r.Proxy); err != nil { return nil, err } return &gctrpc.GenericResponse{ Status: MsgStatusSuccess, - Data: fmt.Sprintf("new proxy has been set [%s] for %s websocket connection", - r.Exchange, - r.Proxy), + Data: fmt.Sprintf("new proxy has been set [%s] for %s websocket connection", r.Exchange, r.Proxy), }, nil } diff --git a/engine/websocketroutine_manager.go b/engine/websocketroutine_manager.go index 0c02af3c872..59ac501ff05 100644 --- a/engine/websocketroutine_manager.go +++ b/engine/websocketroutine_manager.go @@ -1,6 +1,7 @@ package engine import ( + "context" "fmt" "sync" "sync/atomic" @@ -139,7 +140,7 @@ func (m *WebsocketRoutineManager) websocketRoutine() { log.Errorf(log.WebsocketMgr, "%v", err) } - if err := ws.Connect(); err != nil { + if err := ws.Connect(context.TODO()); err != nil { log.Errorf(log.WebsocketMgr, "%v", err) } }) @@ -167,14 +168,13 @@ func (m *WebsocketRoutineManager) websocketDataReceiver(ws *websocket.Manager) e select { case <-m.shutdown: return - case data := <-ws.ToRoutine: - if data == nil { + case payload := <-ws.DataHandler.C: + if payload.Data == nil { log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName()) } m.mu.RLock() for x := range m.dataHandlers { - err := m.dataHandlers[x](ws.GetName(), data) - if err != nil { + if err := m.dataHandlers[x](ws.GetName(), payload.Data); err != nil { log.Errorln(log.WebsocketMgr, err) } } diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index 6b00092733d..21ab43e3490 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -258,16 +258,18 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { } mock := websocket.NewManager() - mock.ToRoutine = make(chan any) m.state = readyState err = m.websocketDataReceiver(mock) if err != nil { t.Fatal(err) } - mock.ToRoutine <- nil - mock.ToRoutine <- 1336 - mock.ToRoutine <- "intercepted" + err = mock.DataHandler.Send(t.Context(), nil) + require.NoError(t, err) + err = mock.DataHandler.Send(t.Context(), 1336) + require.NoError(t, err) + err = mock.DataHandler.Send(t.Context(), "intercepted") + require.NoError(t, err) if r := <-dataChan; r != "intercepted" { t.Fatal("unexpected value received") diff --git a/exchange/stream/relay.go b/exchange/stream/relay.go new file mode 100644 index 00000000000..5225e4a9a67 --- /dev/null +++ b/exchange/stream/relay.go @@ -0,0 +1,48 @@ +package stream + +import ( + "context" + "errors" + "fmt" + + "github.com/thrasher-corp/gocryptotrader/common" +) + +var errChannelBufferFull = errors.New("channel buffer is full") + +// Relay defines a channel relay for messages +type Relay struct { + C <-chan Payload + comm chan Payload +} + +// Payload represents a relayed message with a context +type Payload struct { + Ctx common.FrozenContext + Data any +} + +// NewRelay creates a new Relay instance with a specified buffer size +func NewRelay(buffer uint) *Relay { + if buffer == 0 { + panic("buffer size must be greater than 0") + } + comm := make(chan Payload, buffer) + return &Relay{comm: comm, C: comm} +} + +// Send sends a message to the channel receiver +// This is non-blocking and returns an error if the channel buffer is full +func (r *Relay) Send(ctx context.Context, data any) error { + select { + case r.comm <- Payload{Ctx: common.FreezeContext(ctx), Data: data}: + return nil + default: + return fmt.Errorf("%w: failed to relay <%T>", errChannelBufferFull, data) + } +} + +// Close closes the relay channel +func (r *Relay) Close() { + close(r.comm) +} diff --git a/exchange/stream/relay_test.go b/exchange/stream/relay_test.go new file mode 100644 index 00000000000..9224e0babb6 --- /dev/null +++ b/exchange/stream/relay_test.go @@ -0,0 +1,43 @@ +package stream + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRelay(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { NewRelay(0) }, "buffer size should be greater than 0") + r := NewRelay(5) + require.NotNil(t, r) + assert.Equal(t, 5, cap(r.comm)) +} + +func TestSend(t *testing.T) { + t.Parallel() + r := NewRelay(1) + require.NotNil(t, r) + assert.NoError(t, r.Send(t.Context(), "test")) + assert.ErrorIs(t, r.Send(t.Context(), "overflow"), errChannelBufferFull) +} + +func TestRead(t *testing.T) { + t.Parallel() + r := NewRelay(1) + require.NotNil(t, r) + require.Empty(t, r.C) + assert.NoError(t, r.Send(t.Context(), "test")) + require.Len(t, r.C, 1) + assert.Equal(t, "test", (<-r.C).Data) +} + +func TestClose(t *testing.T) { + t.Parallel() + r := NewRelay(1) + require.NotNil(t, r) + r.Close() + _, ok := <-r.C + assert.False(t, ok) +} diff --git a/exchange/websocket/buffer/buffer.go b/exchange/websocket/buffer/buffer.go index 527bf211ca3..baee4707ad2 100644 --- a/exchange/websocket/buffer/buffer.go +++ b/exchange/websocket/buffer/buffer.go @@ -2,13 +2,16 @@ package buffer import ( "cmp" + "context" "errors" "fmt" "slices" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/key" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" ) @@ -16,26 +19,16 @@ import ( const packageError = "websocket orderbook buffer error: %w" var ( - errExchangeConfigNil = errors.New("exchange config is nil") - errBufferConfigNil = errors.New("buffer config is nil") - errUnsetDataHandler = errors.New("datahandler unset") errIssueBufferEnabledButNoLimit = errors.New("buffer enabled but no limit set") errOrderbookFlushed = errors.New("orderbook flushed") ) // Setup sets private variables -func (o *Orderbook) Setup(exchangeConfig *config.Exchange, c *Config, dataHandler chan<- any) error { - if exchangeConfig == nil { // exchange config fields are checked in websocket package prior to calling this, so further checks are not needed - return fmt.Errorf(packageError, errExchangeConfigNil) - } - if c == nil { - return fmt.Errorf(packageError, errBufferConfigNil) - } - if dataHandler == nil { - return fmt.Errorf(packageError, errUnsetDataHandler) +func (o *Orderbook) Setup(exchangeConfig *config.Exchange, c *Config, dataHandler *stream.Relay) error { + if err := common.NilGuard(exchangeConfig, c, dataHandler); err != nil { + return err } - if exchangeConfig.Orderbook.WebsocketBufferEnabled && - exchangeConfig.Orderbook.WebsocketBufferLimit < 1 { + if exchangeConfig.Orderbook.WebsocketBufferEnabled && exchangeConfig.Orderbook.WebsocketBufferLimit < 1 { return fmt.Errorf(packageError, errIssueBufferEnabledButNoLimit) } @@ -54,6 +47,7 @@ func (o *Orderbook) Setup(exchangeConfig *config.Exchange, c *Config, dataHandle // LoadSnapshot loads initial snapshot of orderbook data from websocket func (o *Orderbook) LoadSnapshot(book *orderbook.Book) error { + ctx := context.TODO() if err := book.Validate(); err != nil { return err } @@ -81,8 +75,7 @@ func (o *Orderbook) LoadSnapshot(book *orderbook.Book) error { } holder.ob.Publish() - o.dataHandler <- holder.ob - return nil + return o.dataHandler.Send(ctx, holder.ob) } // Update updates a stored pointer to an orderbook.Depth struct containing bid and ask Tranches, this switches between @@ -107,8 +100,7 @@ func (o *Orderbook) Update(u *orderbook.Update) error { // Publish all state changes, disregarding verbosity or sync requirements. holder.ob.Publish() - o.dataHandler <- holder.ob - return nil + return o.dataHandler.Send(context.TODO(), holder.ob) } // processBufferUpdate stores update into buffer, when buffer at capacity as diff --git a/exchange/websocket/buffer/buffer_test.go b/exchange/websocket/buffer/buffer_test.go index dda797d0372..8a05f8c104b 100644 --- a/exchange/websocket/buffer/buffer_test.go +++ b/exchange/websocket/buffer/buffer_test.go @@ -12,6 +12,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common/key" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" ) @@ -53,15 +54,15 @@ func createSnapshot(pair currency.Pair) (holder *Orderbook, asks, bids orderbook newBook := make(map[key.PairAsset]*orderbookHolder) - ch := make(chan any) - go func(<-chan any) { // reader - for range ch { + relay := stream.NewRelay(10) + go func(relay *stream.Relay) { // reader + for range relay.C { continue } - }(ch) + }(relay) holder = &Orderbook{ exchangeName: exchangeName, - dataHandler: ch, + dataHandler: relay, ob: newBook, } err = holder.LoadSnapshot(book) @@ -432,7 +433,7 @@ func TestRunSnapshotWithNoData(t *testing.T) { var obl Orderbook obl.ob = make(map[key.PairAsset]*orderbookHolder) - obl.dataHandler = make(chan any, 1) + obl.dataHandler = stream.NewRelay(1) var snapShot1 orderbook.Book snapShot1.Asset = asset.Spot snapShot1.Pair = cp @@ -449,7 +450,7 @@ func TestLoadSnapshot(t *testing.T) { require.NoError(t, err) var obl Orderbook - obl.dataHandler = make(chan any, 100) + obl.dataHandler = stream.NewRelay(100) obl.ob = make(map[key.PairAsset]*orderbookHolder) err = obl.LoadSnapshot(&orderbook.Book{Asks: orderbook.Levels{{Amount: 1}}, ValidateOrderbook: true}) @@ -502,7 +503,7 @@ func TestInsertingSnapShots(t *testing.T) { require.NoError(t, err) var holder Orderbook - holder.dataHandler = make(chan any, 100) + holder.dataHandler = stream.NewRelay(100) holder.ob = make(map[key.PairAsset]*orderbookHolder) var snapShot1 orderbook.Book snapShot1.Exchange = "WSORDERBOOKTEST1" @@ -698,18 +699,18 @@ func TestSetup(t *testing.T) { t.Parallel() w := Orderbook{} err := w.Setup(nil, nil, nil) - require.ErrorIs(t, err, errExchangeConfigNil) + require.ErrorIs(t, err, common.ErrNilPointer) exchangeConfig := &config.Exchange{} err = w.Setup(exchangeConfig, nil, nil) - require.ErrorIs(t, err, errBufferConfigNil) + require.ErrorIs(t, err, common.ErrNilPointer) bufferConf := &Config{} err = w.Setup(exchangeConfig, bufferConf, nil) - require.ErrorIs(t, err, errUnsetDataHandler) + require.ErrorIs(t, err, common.ErrNilPointer) exchangeConfig.Orderbook.WebsocketBufferEnabled = true - err = w.Setup(exchangeConfig, bufferConf, make(chan any)) + err = w.Setup(exchangeConfig, bufferConf, stream.NewRelay(1)) require.ErrorIs(t, err, errIssueBufferEnabledButNoLimit) exchangeConfig.Orderbook.WebsocketBufferLimit = 1337 @@ -717,7 +718,7 @@ func TestSetup(t *testing.T) { exchangeConfig.Name = "test" bufferConf.SortBuffer = true bufferConf.SortBufferByUpdateIDs = true - err = w.Setup(exchangeConfig, bufferConf, make(chan any)) + err = w.Setup(exchangeConfig, bufferConf, stream.NewRelay(1)) require.NoError(t, err) require.Equal(t, 1337, w.obBufferLimit) @@ -733,7 +734,7 @@ func TestInvalidateOrderbook(t *testing.T) { require.NoError(t, err) w := &Orderbook{} - err = w.Setup(&config.Exchange{Name: "test"}, &Config{}, make(chan any, 2)) + err = w.Setup(&config.Exchange{Name: "test"}, &Config{}, stream.NewRelay(2)) require.NoError(t, err) var snapShot1 orderbook.Book diff --git a/exchange/websocket/buffer/buffer_types.go b/exchange/websocket/buffer/buffer_types.go index 142f06da679..b2ed1b36def 100644 --- a/exchange/websocket/buffer/buffer_types.go +++ b/exchange/websocket/buffer/buffer_types.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/thrasher-corp/gocryptotrader/common/key" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" ) @@ -26,7 +27,7 @@ type Orderbook struct { sortBuffer bool sortBufferByUpdateIDs bool // When timestamps aren't provided, an id can help sort exchangeName string - dataHandler chan<- any + dataHandler *stream.Relay verbose bool m sync.RWMutex diff --git a/exchange/websocket/manager.go b/exchange/websocket/manager.go index 853c8665781..1956e98dce0 100644 --- a/exchange/websocket/manager.go +++ b/exchange/websocket/manager.go @@ -12,6 +12,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/fill" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -100,8 +101,7 @@ type Manager struct { Unsubscriber func(subscription.List) error GenerateSubs func() (subscription.List, error) useMultiConnectionManagement bool - DataHandler chan any - ToRoutine chan any + DataHandler *stream.Relay Match *Match ShutdownC chan struct{} Wg sync.WaitGroup @@ -175,8 +175,7 @@ func SetupGlobalReporter(r Reporter) { // NewManager initialises the websocket struct func NewManager() *Manager { return &Manager{ - DataHandler: make(chan any, jobBuffer), - ToRoutine: make(chan any, jobBuffer), + DataHandler: stream.NewRelay(jobBuffer), ShutdownC: make(chan struct{}), TrafficAlert: make(chan struct{}, 1), // ReadMessageErrors is buffered for an edge case when `Connect` fails @@ -407,13 +406,13 @@ func (m *Manager) getConnectionFromSetup(c *ConnectionSetup) *connection { // Connect initiates a websocket connection by using a package defined connection // function -func (m *Manager) Connect() error { +func (m *Manager) Connect(ctx context.Context) error { m.m.Lock() defer m.m.Unlock() - return m.connect() + return m.connect(ctx) } -func (m *Manager) connect() error { +func (m *Manager) connect(ctx context.Context) error { if !m.IsEnabled() { return ErrWebsocketNotEnabled } @@ -431,9 +430,8 @@ func (m *Manager) connect() error { m.setState(connectingState) - m.Wg.Add(2) - go m.monitorFrame(&m.Wg, m.monitorData) - go m.monitorFrame(&m.Wg, m.monitorTraffic) + m.Wg.Add(1) + go m.monitorFrame(ctx, &m.Wg, m.monitorTraffic) if !m.useMultiConnectionManagement { if m.connector == nil { @@ -448,7 +446,7 @@ func (m *Manager) connect() error { if m.connectionMonitorRunning.CompareAndSwap(false, true) { // This oversees all connections and does not need to be part of wait group management. - go m.monitorFrame(nil, m.monitorConnection) + go m.monitorFrame(ctx, nil, m.monitorConnection) } subs, err := m.GenerateSubs() // regenerate state on new connection @@ -456,7 +454,7 @@ func (m *Manager) connect() error { return fmt.Errorf("%s websocket: %w", m.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } if len(subs) != 0 { - if err := m.SubscribeToChannels(nil, subs); err != nil { + if err := m.SubscribeToChannels(ctx, nil, subs); err != nil { return err } @@ -521,7 +519,7 @@ func (m *Manager) connect() error { conn := m.getConnectionFromSetup(m.connectionManager[i].setup) - if err := m.connectionManager[i].setup.Connector(context.TODO(), conn); err != nil { + if err := m.connectionManager[i].setup.Connector(ctx, conn); err != nil { multiConnectFatalError = fmt.Errorf("%v Error connecting %w", m.exchangeName, err) break } @@ -535,10 +533,10 @@ func (m *Manager) connect() error { m.connectionManager[i].connection = conn m.Wg.Add(1) - go m.Reader(context.TODO(), conn, m.connectionManager[i].setup.Handler) + go m.Reader(ctx, conn, m.connectionManager[i].setup.Handler) if m.connectionManager[i].setup.Authenticate != nil && m.CanUseAuthenticatedEndpoints() { - if err := m.connectionManager[i].setup.Authenticate(context.TODO(), conn); err != nil { + if err := m.connectionManager[i].setup.Authenticate(ctx, conn); err != nil { multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", m.exchangeName, i+1, conn.URL, err) break } @@ -548,7 +546,7 @@ func (m *Manager) connect() error { continue } - if err := m.connectionManager[i].setup.Subscriber(context.TODO(), conn, subs); err != nil { + if err := m.connectionManager[i].setup.Subscriber(ctx, conn, subs); err != nil { subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", m.exchangeName, err)) continue } @@ -596,7 +594,7 @@ func (m *Manager) connect() error { if m.connectionMonitorRunning.CompareAndSwap(false, true) { // This oversees all connections and does not need to be part of wait group management. - go m.monitorFrame(nil, m.monitorConnection) + go m.monitorFrame(ctx, nil, m.monitorConnection) } return subscriptionError @@ -614,13 +612,13 @@ func (m *Manager) Disable() error { } // Enable enables the exchange websocket protocol -func (m *Manager) Enable() error { +func (m *Manager) Enable(ctx context.Context) error { if m.IsConnected() || m.IsEnabled() { return fmt.Errorf("%s %w", m.exchangeName, ErrWebsocketAlreadyEnabled) } m.setEnabled(true) - return m.Connect() + return m.Connect(ctx) } // Shutdown attempts to shut down a websocket connection and associated routines @@ -809,7 +807,7 @@ func (m *Manager) GetWebsocketURL() string { } // SetProxyAddress sets websocket proxy address -func (m *Manager) SetProxyAddress(proxyAddr string) error { +func (m *Manager) SetProxyAddress(ctx context.Context, proxyAddr string) error { m.m.Lock() defer m.m.Unlock() if proxyAddr != "" { @@ -846,7 +844,7 @@ func (m *Manager) SetProxyAddress(proxyAddr string) error { if err := m.shutdown(); err != nil { return err } - return m.connect() + return m.connect(ctx) } // GetProxyAddress returns the current websocket proxy @@ -890,7 +888,10 @@ func (m *Manager) Reader(ctx context.Context, conn Connection, handler func(ctx return // Connection has been closed } if err := handler(ctx, conn, resp.Raw); err != nil { - m.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), err) + err = fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), err) + if errSend := m.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s: %s %s", m.exchangeName, errSend, err) + } } } } @@ -906,16 +907,16 @@ func drain(ch <-chan error) { } // ClosureFrame is a closure function that wraps monitoring variables with observer, if the return is true the frame will exit -type ClosureFrame func() func() bool +type ClosureFrame func(ctx context.Context) func() bool // monitorFrame monitors a specific websocket component or critical system. It will exit if the observer returns true // This is used for monitoring data throughput, connection status and other critical websocket components. The waitgroup // is optional and is used to signal when the monitor has finished. -func (m *Manager) monitorFrame(wg *sync.WaitGroup, fn ClosureFrame) { +func (m *Manager) monitorFrame(ctx context.Context, wg *sync.WaitGroup, fn ClosureFrame) { if wg != nil { defer m.Wg.Done() } - observe := fn() + observe := fn(ctx) for { if observe() { return @@ -923,43 +924,14 @@ func (m *Manager) monitorFrame(wg *sync.WaitGroup, fn ClosureFrame) { } } -// monitorData monitors data throughput and logs if there is a back log of data -func (m *Manager) monitorData() func() bool { - dropped := 0 - return func() bool { return m.observeData(&dropped) } -} - -// observeData observes data throughput and logs if there is a back log of data -func (m *Manager) observeData(dropped *int) (exit bool) { - select { - case <-m.ShutdownC: - return true - case d := <-m.DataHandler: - select { - case m.ToRoutine <- d: - if *dropped != 0 { - log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", m.exchangeName, dropped) - *dropped = 0 - } - default: - if *dropped == 0 { - // If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible - log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", m.exchangeName) - } - *dropped++ - } - return false - } -} - // monitorConnection monitors the connection and attempts to reconnect if the connection is lost -func (m *Manager) monitorConnection() func() bool { +func (m *Manager) monitorConnection(ctx context.Context) func() bool { timer := time.NewTimer(m.connectionMonitorDelay) - return func() bool { return m.observeConnection(timer) } + return func() bool { return m.observeConnection(ctx, timer) } } // observeConnection observes the connection and attempts to reconnect if the connection is lost -func (m *Manager) observeConnection(t *time.Timer) (exit bool) { +func (m *Manager) observeConnection(ctx context.Context, t *time.Timer) (exit bool) { select { case err := <-m.ReadMessageErrors: if errors.Is(err, errConnectionFault) { @@ -972,11 +944,13 @@ func (m *Manager) observeConnection(t *time.Timer) (exit bool) { } // Speedier reconnection, instead of waiting for the next cycle. if m.IsEnabled() && (!m.IsConnected() && !m.IsConnecting()) { - if connectErr := m.Connect(); connectErr != nil { + if connectErr := m.Connect(ctx); connectErr != nil { log.Errorln(log.WebsocketMgr, connectErr) } } - m.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority) + if err := m.DataHandler.Send(ctx, err); err != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor data handler err: %s", m.exchangeName, err) + } case <-t.C: if m.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", m.exchangeName) @@ -998,7 +972,7 @@ func (m *Manager) observeConnection(t *time.Timer) (exit bool) { return true } if !m.IsConnecting() && !m.IsConnected() { - err := m.Connect() + err := m.Connect(ctx) if err != nil { log.Errorln(log.WebsocketMgr, err) } @@ -1010,24 +984,22 @@ func (m *Manager) observeConnection(t *time.Timer) (exit bool) { // monitorTraffic monitors to see if there has been traffic within the trafficTimeout time window. If there is no traffic // the connection is shutdown and will be reconnected by the connectionMonitor routine. -func (m *Manager) monitorTraffic() func() bool { - timer := time.NewTimer(m.trafficTimeout) - return func() bool { return m.observeTraffic(timer) } +func (m *Manager) monitorTraffic(context.Context) func() bool { + return func() bool { return m.observeTraffic(m.trafficTimeout) } } -func (m *Manager) observeTraffic(t *time.Timer) bool { +func (m *Manager) observeTraffic(timeout time.Duration) bool { select { case <-m.ShutdownC: if m.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", m.exchangeName) } - case <-t.C: + case <-time.After(timeout): if m.IsConnecting() || signalReceived(m.TrafficAlert) { - t.Reset(m.trafficTimeout) return false } if m.verbose { - log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", m.exchangeName, m.trafficTimeout) + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", m.exchangeName, timeout) } if m.IsConnected() { go func() { // Without this the m.Shutdown() call below will deadlock @@ -1037,7 +1009,6 @@ func (m *Manager) observeTraffic(t *time.Timer) bool { }() } } - t.Stop() return true } diff --git a/exchange/websocket/manager_test.go b/exchange/websocket/manager_test.go index 15217caa886..6f1c2017b88 100644 --- a/exchange/websocket/manager_test.go +++ b/exchange/websocket/manager_test.go @@ -22,6 +22,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" @@ -98,7 +99,7 @@ func TestSetup(t *testing.T) { err := w.Setup(nil) assert.ErrorContains(t, err, "nil pointer: *websocket.Manager") - w = &Manager{DataHandler: make(chan any)} + w = &Manager{DataHandler: stream.NewRelay(1)} err = w.Setup(nil) assert.ErrorContains(t, err, "nil pointer: *websocket.ManagerSetup") @@ -167,24 +168,26 @@ func TestConnectionMessageErrors(t *testing.T) { t.Parallel() wsWrong := &Manager{} wsWrong.connector = func() error { return nil } - err := wsWrong.Connect() - assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") + + wsWrong.DataHandler = stream.NewRelay(1) + err := wsWrong.Connect(t.Context()) + require.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect must error correctly") wsWrong.setEnabled(true) wsWrong.setState(connectingState) - err = wsWrong.Connect() - assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") + err = wsWrong.Connect(t.Context()) + require.ErrorIs(t, err, errAlreadyReconnecting, "Connect must error correctly") wsWrong.setState(disconnectedState) - err = wsWrong.Connect() - assert.ErrorIs(t, err, common.ErrNilPointer, "Connect should get a nil pointer error") - assert.ErrorContains(t, err, "subscriptions", "Connect should get a nil pointer error about subscriptions") + err = wsWrong.Connect(t.Context()) + require.ErrorIs(t, err, common.ErrNilPointer, "Connect must get a nil pointer error") + require.ErrorContains(t, err, "subscriptions", "Connect must get a nil pointer error about subscriptions") wsWrong.subscriptions = subscription.NewStore() wsWrong.setState(disconnectedState) wsWrong.connector = func() error { return errDastardlyReason } - err = wsWrong.Connect() - assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") + err = wsWrong.Connect(t.Context()) + require.ErrorIs(t, err, errDastardlyReason, "Connect must error correctly") ws := NewManager() err = ws.Setup(newDefaultSetup()) @@ -192,7 +195,7 @@ func TestConnectionMessageErrors(t *testing.T) { ws.trafficTimeout = time.Minute ws.connector = connect - require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded) + require.ErrorIs(t, ws.Connect(t.Context()), ErrSubscriptionsNotAdded) require.NoError(t, ws.Shutdown()) ws.Subscriber = func(subs subscription.List) error { @@ -203,13 +206,13 @@ func TestConnectionMessageErrors(t *testing.T) { } return nil } - require.NoError(t, ws.Connect(), "Connect must not error") + require.NoError(t, ws.Connect(t.Context()), "Connect must not error") checkToRoutineResult := func(t *testing.T) { t.Helper() - v, ok := <-ws.ToRoutine + v, ok := <-ws.DataHandler.C require.True(t, ok, "ToRoutine must not be closed on us") - switch err := v.(type) { + switch err := v.Data.(type) { case *gws.CloseError: assert.Equal(t, "SpecialText", err.Text, "Should get correct Close Error") case error: @@ -230,7 +233,7 @@ func TestConnectionMessageErrors(t *testing.T) { require.NoError(t, ws.Shutdown()) ws.useMultiConnectionManagement = true - err = ws.Connect() + err = ws.Connect(t.Context()) assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") ws.useMultiConnectionManagement = true @@ -239,7 +242,7 @@ func TestConnectionMessageErrors(t *testing.T) { mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() ws.connectionManager = []*connectionWrapper{{setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) ws.connectionManager[0].setup.Authenticate = func(context.Context, Connection) error { return errDastardlyReason } @@ -247,57 +250,57 @@ func TestConnectionMessageErrors(t *testing.T) { ws.connectionManager[0].setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, errDastardlyReason } - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errDastardlyReason) ws.connectionManager[0].setup.GenerateSubscriptions = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil } - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errNoConnectFunc) ws.connectionManager[0].setup.Connector = func(context.Context, Connection) error { return errDastardlyReason } - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errWebsocketDataHandlerUnset) ws.connectionManager[0].setup.Handler = func(context.Context, Connection, []byte) error { return errDastardlyReason } - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errWebsocketSubscriberUnset) ws.connectionManager[0].setup.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errDastardlyReason) ws.connectionManager[0].setup.Connector = func(ctx context.Context, conn Connection) error { return conn.Dial(ctx, gws.DefaultDialer, nil) } - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errDastardlyReason) ws.connectionManager[0].setup.Handler = func(context.Context, Connection, []byte) error { return errDastardlyReason } - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errDastardlyReason) ws.connectionManager[0].setup.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } ws.connectionManager[0].setup.Authenticate = nil - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, errDastardlyReason) require.NoError(t, ws.shutdown()) ws.connectionManager[0].setup.Subscriber = func(context.Context, Connection, subscription.List) error { return nil } - err = ws.Connect() + err = ws.Connect(t.Context()) require.ErrorIs(t, err, ErrSubscriptionsNotAdded) require.NoError(t, ws.shutdown()) @@ -305,7 +308,7 @@ func TestConnectionMessageErrors(t *testing.T) { ws.connectionManager[0].setup.Subscriber = func(context.Context, Connection, subscription.List) error { return ws.connectionManager[0].subscriptions.Add(&subscription.Subscription{Channel: "test"}) } - err = ws.Connect() + err = ws.Connect(t.Context()) require.NoError(t, err) err = ws.connectionManager[0].connection.SendRawMessage(t.Context(), request.Unset, gws.TextMessage, []byte("test")) @@ -320,7 +323,7 @@ func TestManager(t *testing.T) { ws := NewManager() - err := ws.SetProxyAddress("garbagio") + err := ws.SetProxyAddress(t.Context(), "garbagio") assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") ws.setEnabled(true) @@ -340,24 +343,24 @@ func TestManager(t *testing.T) { ws.setEnabled(true) assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") - err = ws.SetProxyAddress("https://192.168.0.1:1337") + err = ws.SetProxyAddress(t.Context(), "https://192.168.0.1:1337") assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") ws.setState(connectedState) ws.connector = func() error { return errDastardlyReason } - err = ws.SetProxyAddress("https://192.168.0.1:1336") + err = ws.SetProxyAddress(t.Context(), "https://192.168.0.1:1336") assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") - err = ws.SetProxyAddress("https://192.168.0.1:1336") + err = ws.SetProxyAddress(t.Context(), "https://192.168.0.1:1336") assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") // removing proxy - assert.NoError(t, ws.SetProxyAddress("")) + assert.NoError(t, ws.SetProxyAddress(t.Context(), "")) ws.setEnabled(true) // reinstate proxy - err = ws.SetProxyAddress("http://localhost:1337") + err = ws.SetProxyAddress(t.Context(), "http://localhost:1337") assert.NoError(t, err, "SetProxyAddress should not error") assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") @@ -369,7 +372,7 @@ func TestManager(t *testing.T) { ws.connector = func() error { return nil } - require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded) + require.ErrorIs(t, ws.Connect(t.Context()), ErrSubscriptionsNotAdded) require.NoError(t, ws.Shutdown()) ws.Subscriber = func(subs subscription.List) error { @@ -380,7 +383,7 @@ func TestManager(t *testing.T) { } return nil } - assert.NoError(t, ws.Connect(), "Connect should not error") + assert.NoError(t, ws.Connect(t.Context()), "Connect should not error") ws.defaultURL = "ws://demos.kaazing.com/echo" ws.defaultURLAuth = "ws://demos.kaazing.com/echo" @@ -401,10 +404,10 @@ func TestManager(t *testing.T) { assert.NoError(t, err, "SetWebsocketURL should not error on reconnect") // -- initiate the reconnect which is usually handled by connection monitor - err = ws.Connect() + err = ws.Connect(t.Context()) assert.NoError(t, err, "ReConnect called manually should not error") - err = ws.Connect() + err = ws.Connect(t.Context()) assert.ErrorIs(t, err, errAlreadyConnected, "ReConnect should error when already connected") err = ws.Shutdown() @@ -414,7 +417,7 @@ func TestManager(t *testing.T) { ws.useMultiConnectionManagement = true ws.connectionManager = []*connectionWrapper{{setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, connection: &connection{}}} - err = ws.SetProxyAddress("https://192.168.0.1:1337") + err = ws.SetProxyAddress(t.Context(), "https://192.168.0.1:1337") require.NoError(t, err) } @@ -624,9 +627,9 @@ type reporter struct { t time.Duration } -func (r *reporter) Latency(name string, message []byte, t time.Duration) { +func (r *reporter) Latency(name string, payload []byte, t time.Duration) { r.name = name - r.msg = message + r.msg = payload r.t = t } @@ -821,11 +824,11 @@ func TestFlushChannels(t *testing.T) { // Enabled pairs/setup system dodgyWs := Manager{} - err := dodgyWs.FlushChannels() + err := dodgyWs.FlushChannels(t.Context()) assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "FlushChannels should error correctly") dodgyWs.setEnabled(true) - err = dodgyWs.FlushChannels() + err = dodgyWs.FlushChannels(t.Context()) assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") newgen := GenSubs{EnabledPairs: []currency.Pair{ @@ -852,7 +855,7 @@ func TestFlushChannels(t *testing.T) { newgen.EnabledPairs = []currency.Pair{currency.NewPair(currency.BTC, currency.AUD)} w.GenerateSubs = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil } - require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "FlushChannels must error correctly on no subscriptions added") + require.ErrorIs(t, w.FlushChannels(t.Context()), ErrSubscriptionsNotAdded, "FlushChannels must error correctly on no subscriptions added") w.Subscriber = func(subs subscription.List) error { for _, sub := range subs { @@ -863,15 +866,15 @@ func TestFlushChannels(t *testing.T) { return nil } - require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + require.NoError(t, w.FlushChannels(t.Context()), "FlushChannels must not error") w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs - err = w.FlushChannels() // error on full subscribeToChannels + err = w.FlushChannels(t.Context()) // error on full subscribeToChannels assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs") w.GenerateSubs = func() (subscription.List, error) { return nil, nil } // No subs to sub - require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved) + require.ErrorIs(t, w.FlushChannels(t.Context()), ErrSubscriptionsNotRemoved) w.Unsubscriber = func(subs subscription.List) error { for _, sub := range subs { @@ -881,13 +884,13 @@ func TestFlushChannels(t *testing.T) { } return nil } - assert.NoError(t, w.FlushChannels(), "FlushChannels should not error") + assert.NoError(t, w.FlushChannels(t.Context()), "FlushChannels should not error") w.GenerateSubs = newgen.generateSubs subs, err := w.GenerateSubs() require.NoError(t, err, "GenerateSubs must not error") require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error") - err = w.FlushChannels() + err = w.FlushChannels(t.Context()) assert.NoError(t, err, "FlushChannels should not error") w.GenerateSubs = newgen.generateSubs @@ -905,11 +908,11 @@ func TestFlushChannels(t *testing.T) { }) require.NoError(t, err, "AddSubscription must not error") - err = w.FlushChannels() + err = w.FlushChannels(t.Context()) assert.NoError(t, err, "FlushChannels should not error") w.setState(connectedState) - err = w.FlushChannels() + err = w.FlushChannels(t.Context()) assert.NoError(t, err, "FlushChannels should not error") // Multi connection management @@ -930,27 +933,27 @@ func TestFlushChannels(t *testing.T) { Handler: func(context.Context, Connection, []byte) error { return nil }, } require.NoError(t, w.SetupNewConnection(amazingCandidate)) - require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store") + require.ErrorIs(t, w.FlushChannels(t.Context()), ErrSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store") w.connectionManager[0].setup.Subscriber = func(ctx context.Context, c Connection, s subscription.List) error { return currySimpleSubConn(w)(ctx, c, s) } - require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + require.NoError(t, w.FlushChannels(t.Context()), "FlushChannels must not error") // Forces full connection cycle (shutdown, connect, subscribe). This will also start monitoring routines. w.features.Subscribe = false - require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + require.NoError(t, w.FlushChannels(t.Context()), "FlushChannels must not error") // Unsubscribe what's already subscribed. No subscriptions left over, which then forces the shutdown and removal // of the connection from management. w.features.Subscribe = true w.connectionManager[0].setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } - require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved, "Must error when no subscriptions are removed from subscription store") + require.ErrorIs(t, w.FlushChannels(t.Context()), ErrSubscriptionsNotRemoved, "Must error when no subscriptions are removed from subscription store") w.connectionManager[0].setup.Unsubscriber = func(ctx context.Context, c Connection, s subscription.List) error { return currySimpleUnsubConn(w)(ctx, c, s) } - require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + require.NoError(t, w.FlushChannels(t.Context()), "FlushChannels must not error") } func TestDisable(t *testing.T) { @@ -969,8 +972,8 @@ func TestEnable(t *testing.T) { w.Subscriber = func(subscription.List) error { return nil } w.Unsubscriber = func(subscription.List) error { return nil } w.GenerateSubs = func() (subscription.List, error) { return nil, nil } - require.NoError(t, w.Enable(), "Enable must not error") - assert.ErrorIs(t, w.Enable(), ErrWebsocketAlreadyEnabled, "Enable should error correctly") + require.NoError(t, w.Enable(t.Context()), "Enable must not error") + assert.ErrorIs(t, w.Enable(t.Context()), ErrWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { @@ -1162,39 +1165,13 @@ func TestDrain(t *testing.T) { func TestMonitorFrame(t *testing.T) { t.Parallel() ws := Manager{} - require.Panics(t, func() { ws.monitorFrame(nil, nil) }, "monitorFrame must panic on nil frame") - require.Panics(t, func() { ws.monitorFrame(nil, func() func() bool { return nil }) }, "monitorFrame must panic on nil function") + require.Panics(t, func() { ws.monitorFrame(t.Context(), nil, nil) }, "monitorFrame must panic on nil frame") + require.Panics(t, func() { ws.monitorFrame(t.Context(), nil, func(context.Context) func() bool { return nil }) }, "monitorFrame must panic on nil function") ws.Wg.Add(1) - ws.monitorFrame(&ws.Wg, func() func() bool { return func() bool { return true } }) + ws.monitorFrame(t.Context(), &ws.Wg, func(context.Context) func() bool { return func() bool { return true } }) ws.Wg.Wait() } -func TestMonitorData(t *testing.T) { - t.Parallel() - ws := Manager{ShutdownC: make(chan struct{}), DataHandler: make(chan any, 10)} - // Handle shutdown signal - close(ws.ShutdownC) - require.True(t, ws.observeData(nil)) - ws.ShutdownC = make(chan struct{}) - // Handle blockage of ToRoutine - go func() { ws.DataHandler <- nil }() - var dropped int - require.False(t, ws.observeData(&dropped)) - require.Equal(t, 1, dropped) - // Handle reinstate of ToRoutine functionality which will reset dropped counter - ws.ToRoutine = make(chan any, 10) - go func() { ws.DataHandler <- nil }() - require.False(t, ws.observeData(&dropped)) - require.Empty(t, dropped) - // Handle outer closure shell - innerShell := ws.monitorData() - go func() { ws.DataHandler <- nil }() - require.False(t, innerShell()) - // Handle shutdown signal - close(ws.ShutdownC) - require.True(t, innerShell()) -} - func TestMonitorConnection(t *testing.T) { t.Parallel() ws := Manager{verbose: true, ReadMessageErrors: make(chan error, 1), ShutdownC: make(chan struct{})} @@ -1202,7 +1179,7 @@ func TestMonitorConnection(t *testing.T) { timer := time.NewTimer(0) ws.setState(connectedState) ws.connectionMonitorRunning.Store(true) - require.True(t, ws.observeConnection(timer)) + require.True(t, ws.observeConnection(t.Context(), timer)) require.False(t, ws.connectionMonitorRunning.Load()) require.Equal(t, disconnectedState, ws.state.Load()) // Handle timer expired and everything is great, reset the timer. @@ -1210,22 +1187,22 @@ func TestMonitorConnection(t *testing.T) { ws.setState(connectedState) ws.connectionMonitorRunning.Store(true) timer = time.NewTimer(0) - require.False(t, ws.observeConnection(timer)) // Not shutting down + require.False(t, ws.observeConnection(t.Context(), timer)) // Not shutting down // Handle timer expired and for reason its not connected, so lets happily connect again. ws.setState(disconnectedState) - require.False(t, ws.observeConnection(timer)) // Connect is intentionally erroring + require.False(t, ws.observeConnection(t.Context(), timer)) // Connect is intentionally erroring // Handle error from a connection which will then trigger a reconnect ws.setState(connectedState) - ws.DataHandler = make(chan any, 1) + ws.DataHandler = stream.NewRelay(1) ws.ReadMessageErrors <- errConnectionFault timer = time.NewTimer(time.Second) - require.False(t, ws.observeConnection(timer)) - payload := <-ws.DataHandler - err, ok := payload.(error) + require.False(t, ws.observeConnection(t.Context(), timer)) + payload := <-ws.DataHandler.C + err, ok := payload.Data.(error) require.True(t, ok) require.ErrorIs(t, err, errConnectionFault) // Handle outta closure shell - innerShell := ws.monitorConnection() + innerShell := ws.monitorConnection(t.Context()) ws.setState(connectedState) ws.ReadMessageErrors <- errConnectionFault require.False(t, innerShell()) @@ -1236,27 +1213,23 @@ func TestMonitorTraffic(t *testing.T) { ws := Manager{verbose: true, ShutdownC: make(chan struct{}), TrafficAlert: make(chan struct{}, 1)} ws.Wg.Add(1) // Handle external shutdown signal - timer := time.NewTimer(time.Second) close(ws.ShutdownC) - require.True(t, ws.observeTraffic(timer)) + require.True(t, ws.observeTraffic(time.Second)) // Handle timer expired but system is connecting, so reset the timer ws.ShutdownC = make(chan struct{}) ws.setState(connectingState) - timer = time.NewTimer(0) - require.False(t, ws.observeTraffic(timer)) + require.False(t, ws.observeTraffic(0)) // Handle timer expired and system is connected and has traffic within time window ws.setState(connectedState) - timer = time.NewTimer(0) ws.TrafficAlert <- struct{}{} - require.False(t, ws.observeTraffic(timer)) + require.False(t, ws.observeTraffic(0)) // Handle timer expired and system is connected but no traffic within time window, causes shutdown to occur. - timer = time.NewTimer(0) - require.True(t, ws.observeTraffic(timer)) + require.True(t, ws.observeTraffic(0)) ws.Wg.Done() // Shutdown is done in a routine, so we need to wait for it to finish require.Eventually(t, func() bool { return disconnectedState == ws.state.Load() }, time.Second, time.Millisecond) // Handle outer closure shell - innerShell := ws.monitorTraffic() + innerShell := ws.monitorTraffic(t.Context()) ws.m.Lock() ws.ShutdownC = make(chan struct{}) ws.m.Unlock() diff --git a/exchange/websocket/subscriptions.go b/exchange/websocket/subscriptions.go index b3dc280d118..eb938efc7d6 100644 --- a/exchange/websocket/subscriptions.go +++ b/exchange/websocket/subscriptions.go @@ -24,13 +24,13 @@ var ( ) // UnsubscribeChannels unsubscribes from a list of websocket channel -func (m *Manager) UnsubscribeChannels(conn Connection, channels subscription.List) error { +func (m *Manager) UnsubscribeChannels(ctx context.Context, conn Connection, channels subscription.List) error { if len(channels) == 0 { return nil // No channels to unsubscribe from is not an error } if wrapper, ok := m.connections[conn]; ok && conn != nil { return m.unsubscribe(wrapper.subscriptions, channels, func(channels subscription.List) error { - return wrapper.setup.Unsubscriber(context.TODO(), conn, channels) + return wrapper.setup.Unsubscriber(ctx, conn, channels) }) } @@ -58,20 +58,20 @@ func (m *Manager) unsubscribe(store *subscription.Store, channels subscription.L // ResubscribeToChannel resubscribes to channel // Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription // Errors if subscription is already subscribing -func (m *Manager) ResubscribeToChannel(conn Connection, s *subscription.Subscription) error { +func (m *Manager) ResubscribeToChannel(ctx context.Context, conn Connection, s *subscription.Subscription) error { l := subscription.List{s} if err := s.SetState(subscription.ResubscribingState); err != nil { return fmt.Errorf("%w: %s", err, s) } - if err := m.UnsubscribeChannels(conn, l); err != nil { + if err := m.UnsubscribeChannels(ctx, conn, l); err != nil { return err } - return m.SubscribeToChannels(conn, l) + return m.SubscribeToChannels(ctx, conn, l) } // SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method // Errors are returned for duplicates or exceeding max Subscriptions -func (m *Manager) SubscribeToChannels(conn Connection, subs subscription.List) error { +func (m *Manager) SubscribeToChannels(ctx context.Context, conn Connection, subs subscription.List) error { if slices.Contains(subs, nil) { return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer) } @@ -80,7 +80,7 @@ func (m *Manager) SubscribeToChannels(conn Connection, subs subscription.List) e } if wrapper, ok := m.connections[conn]; ok && conn != nil { - return wrapper.setup.Subscriber(context.TODO(), conn, subs) + return wrapper.setup.Subscriber(ctx, conn, subs) } if m.Subscriber == nil { @@ -255,7 +255,7 @@ func (m *Manager) checkSubscriptions(conn Connection, subs subscription.List) er } // FlushChannels flushes channel subscriptions when there is a pair/asset change -func (m *Manager) FlushChannels() error { +func (m *Manager) FlushChannels(ctx context.Context) error { if !m.IsEnabled() { return fmt.Errorf("%s %w", m.exchangeName, ErrWebsocketNotEnabled) } @@ -272,7 +272,7 @@ func (m *Manager) FlushChannels() error { if err := m.shutdown(); err != nil { return err } - return m.connect() + return m.connect(ctx) } if !m.useMultiConnectionManagement { @@ -280,7 +280,7 @@ func (m *Manager) FlushChannels() error { if err != nil { return err } - return m.updateChannelSubscriptions(nil, m.subscriptions, newSubs) + return m.updateChannelSubscriptions(ctx, nil, m.subscriptions, newSubs) } for x := range m.connectionManager { @@ -301,16 +301,16 @@ func (m *Manager) FlushChannels() error { // If there are subscriptions to subscribe to but no connection to subscribe to, establish a new connection. if m.connectionManager[x].connection == nil { conn := m.getConnectionFromSetup(m.connectionManager[x].setup) - if err := m.connectionManager[x].setup.Connector(context.TODO(), conn); err != nil { + if err := m.connectionManager[x].setup.Connector(ctx, conn); err != nil { return err } m.Wg.Add(1) - go m.Reader(context.TODO(), conn, m.connectionManager[x].setup.Handler) + go m.Reader(ctx, conn, m.connectionManager[x].setup.Handler) m.connections[conn] = m.connectionManager[x] m.connectionManager[x].connection = conn } - if err := m.updateChannelSubscriptions(m.connectionManager[x].connection, m.connectionManager[x].subscriptions, newSubs); err != nil { + if err := m.updateChannelSubscriptions(ctx, m.connectionManager[x].connection, m.connectionManager[x].subscriptions, newSubs); err != nil { return err } @@ -328,10 +328,10 @@ func (m *Manager) FlushChannels() error { // updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels // have been subscribed to or unsubscribed from. -func (m *Manager) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error { +func (m *Manager) updateChannelSubscriptions(ctx context.Context, c Connection, store *subscription.Store, incoming subscription.List) error { subs, unsubs := store.Diff(incoming) if len(unsubs) != 0 { - if err := m.UnsubscribeChannels(c, unsubs); err != nil { + if err := m.UnsubscribeChannels(ctx, c, unsubs); err != nil { return err } @@ -340,7 +340,7 @@ func (m *Manager) updateChannelSubscriptions(c Connection, store *subscription.S } } if len(subs) != 0 { - if err := m.SubscribeToChannels(c, subs); err != nil { + if err := m.SubscribeToChannels(ctx, c, subs); err != nil { return err } diff --git a/exchange/websocket/subscriptions_test.go b/exchange/websocket/subscriptions_test.go index 7fc160e3cec..49bee49d620 100644 --- a/exchange/websocket/subscriptions_test.go +++ b/exchange/websocket/subscriptions_test.go @@ -21,11 +21,11 @@ func TestSubscribeUnsubscribe(t *testing.T) { subs, err := ws.GenerateSubs() require.NoError(t, err, "Generating test subscriptions must not error") - assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") - assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error") - assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") + assert.ErrorIs(t, new(Manager).UnsubscribeChannels(t.Context(), nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") + assert.NoError(t, ws.UnsubscribeChannels(t.Context(), nil, nil), "Unsubscribing from nil should not error") + assert.ErrorIs(t, ws.UnsubscribeChannels(t.Context(), nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") - assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error") + assert.NoError(t, ws.SubscribeToChannels(t.Context(), nil, subs), "Basic Subscribing should not error") assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"}) if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { @@ -43,14 +43,14 @@ func TestSubscribeUnsubscribe(t *testing.T) { } assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") - assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to an nil List should not error") - assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error") + assert.ErrorIs(t, ws.SubscribeToChannels(t.Context(), nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") + assert.NoError(t, ws.SubscribeToChannels(t.Context(), nil, nil), "Subscribe to an nil List should not error") + assert.NoError(t, ws.UnsubscribeChannels(t.Context(), nil, subs), "Unsubscribing should not error") ws.Subscriber = func(subscription.List) error { return errDastardlyReason } - assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") + assert.ErrorIs(t, ws.SubscribeToChannels(t.Context(), nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") - err = ws.SubscribeToChannels(nil, subscription.List{nil}) + err = ws.SubscribeToChannels(t.Context(), nil, subscription.List{nil}) assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") multi := NewManager() @@ -79,15 +79,15 @@ func TestSubscribeUnsubscribe(t *testing.T) { subs, err = amazingCandidate.GenerateSubscriptions() require.NoError(t, err, "Generating test subscriptions must not error") - assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") - assert.ErrorIs(t, new(Manager).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") - assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error") - assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") + assert.ErrorIs(t, new(Manager).UnsubscribeChannels(t.Context(), nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") + assert.ErrorIs(t, new(Manager).UnsubscribeChannels(t.Context(), amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") + assert.NoError(t, multi.UnsubscribeChannels(t.Context(), amazingConn, nil), "Unsubscribing from nil should not error") + assert.ErrorIs(t, multi.UnsubscribeChannels(t.Context(), amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return") - assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error") + assert.ErrorIs(t, multi.SubscribeToChannels(t.Context(), nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error") - assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error") + assert.NoError(t, multi.SubscribeToChannels(t.Context(), amazingConn, subs), "Basic Subscribing should not error") assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions") bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"}) if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { @@ -105,14 +105,14 @@ func TestSubscribeUnsubscribe(t *testing.T) { } assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil") assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") - assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to an nil List should not error") - assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error") + assert.ErrorIs(t, multi.SubscribeToChannels(t.Context(), amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") + assert.NoError(t, multi.SubscribeToChannels(t.Context(), amazingConn, nil), "Subscribe to an nil List should not error") + assert.NoError(t, multi.UnsubscribeChannels(t.Context(), amazingConn, subs), "Unsubscribing should not error") amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } - assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") + assert.ErrorIs(t, multi.SubscribeToChannels(t.Context(), amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") - err = multi.SubscribeToChannels(amazingConn, subscription.List{nil}) + err = multi.SubscribeToChannels(t.Context(), amazingConn, subscription.List{nil}) assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") } @@ -134,9 +134,9 @@ func TestResubscribe(t *testing.T) { channel := subscription.List{{Channel: "resubTest"}} - assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") - assert.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe should not error") - assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed") + assert.ErrorIs(t, ws.ResubscribeToChannel(t.Context(), nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") + assert.NoError(t, ws.SubscribeToChannels(t.Context(), nil, channel), "Subscribe should not error") + assert.NoError(t, ws.ResubscribeToChannel(t.Context(), nil, channel[0]), "Resubscribe should not error now the channel is subscribed") } // TestSubscriptions tests adding, getting and removing subscriptions @@ -245,7 +245,7 @@ func TestUpdateChannelSubscriptions(t *testing.T) { ws := NewManager() store := subscription.NewStore() - err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}}) + err := ws.updateChannelSubscriptions(t.Context(), nil, store, subscription.List{{Channel: "test"}}) require.ErrorIs(t, err, common.ErrNilPointer) require.Zero(t, store.Len()) @@ -259,11 +259,11 @@ func TestUpdateChannelSubscriptions(t *testing.T) { } ws.subscriptions = store - err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}}) + err = ws.updateChannelSubscriptions(t.Context(), nil, store, subscription.List{{Channel: "test"}}) require.NoError(t, err) require.Equal(t, 1, store.Len()) - err = ws.updateChannelSubscriptions(nil, store, subscription.List{}) + err = ws.updateChannelSubscriptions(t.Context(), nil, store, subscription.List{}) require.ErrorIs(t, err, common.ErrNilPointer) ws.Unsubscriber = func(subs subscription.List) error { @@ -275,7 +275,7 @@ func TestUpdateChannelSubscriptions(t *testing.T) { return nil } - err = ws.updateChannelSubscriptions(nil, store, subscription.List{}) + err = ws.updateChannelSubscriptions(t.Context(), nil, store, subscription.List{}) require.NoError(t, err) require.Zero(t, store.Len()) } diff --git a/exchanges/binance/binance_test.go b/exchanges/binance/binance_test.go index 008b2e67908..bab81d7d985 100644 --- a/exchanges/binance/binance_test.go +++ b/exchanges/binance/binance_test.go @@ -1975,12 +1975,12 @@ func BenchmarkWsHandleData(bb *testing.B) { require.Len(bb, lines, 8) go func() { for { - <-e.Websocket.DataHandler + <-e.Websocket.DataHandler.C } }() for bb.Loop() { for x := range lines { - assert.NoError(bb, e.wsHandleData(lines[x])) + assert.NoError(bb, e.wsHandleData(bb.Context(), lines[x])) } } } @@ -2031,7 +2031,7 @@ func TestSubscribeBadResp(t *testing.T) { func TestWsTickerUpdate(t *testing.T) { t.Parallel() pressXToJSON := []byte(`{"stream":"btcusdt@ticker","data":{"e":"24hrTicker","E":1580254809477,"s":"ETHBTC","p":"420.97000000","P":"4.720","w":"9058.27981278","x":"8917.98000000","c":"9338.96000000","Q":"0.17246300","b":"9338.03000000","B":"0.18234600","a":"9339.70000000","A":"0.14097600","o":"8917.99000000","h":"9373.19000000","l":"8862.40000000","v":"72229.53692000","q":"654275356.16896672","O":1580168409456,"C":1580254809456,"F":235294268,"L":235894703,"n":600436}}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -2063,7 +2063,7 @@ func TestWsKlineUpdate(t *testing.T) { "B": "123456" } }}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -2085,7 +2085,7 @@ func TestWsTradeUpdate(t *testing.T) { "m": true, "M": true }}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -2144,7 +2144,7 @@ func TestWsDepthUpdate(t *testing.T) { t.Fatal(err) } - if err := e.wsHandleData(update1); err != nil { + if err := e.wsHandleData(t.Context(), update1); err != nil { t.Fatal(err) } @@ -2180,7 +2180,7 @@ func TestWsDepthUpdate(t *testing.T) { ] }}`) - if err = e.wsHandleData(update2); err != nil { + if err = e.wsHandleData(t.Context(), update2); err != nil { t.Error(err) } @@ -2213,7 +2213,7 @@ func TestWsBalanceUpdate(t *testing.T) { "a": "BTC", "d": "100.00000000", "T": 1573200697068}}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -2245,7 +2245,7 @@ func TestWsOCO(t *testing.T) { } ] }}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -2534,16 +2534,21 @@ func TestWsOrderExecutionReport(t *testing.T) { Pair: currency.NewBTCUSDT(), } // empty the channel. otherwise mock_test will fail - for len(e.Websocket.DataHandler) > 0 { - <-e.Websocket.DataHandler +drain: + for { + select { + case <-e.Websocket.DataHandler.C: + default: + break drain + } } - err := e.wsHandleData(payload) + err := e.wsHandleData(t.Context(), payload) if err != nil { t.Fatal(err) } - res := <-e.Websocket.DataHandler - switch r := res.(type) { + res := <-e.Websocket.DataHandler.C + switch r := res.Data.(type) { case *order.Detail: if !reflect.DeepEqual(expectedResult, *r) { t.Errorf("Results do not match:\nexpected: %v\nreceived: %v", expectedResult, *r) @@ -2553,7 +2558,7 @@ func TestWsOrderExecutionReport(t *testing.T) { } payload = []byte(`{"stream":"jTfvpakT2yT0hVIo5gYWVihZhdM2PrBgJUZ5PyfZ4EVpCkx4Uoxk5timcrQc","data":{"e":"executionReport","E":1616633041556,"s":"BTCUSDT","c":"YeULctvPAnHj5HXCQo9Mob","S":"BUY","o":"LIMIT","f":"GTC","q":"0.00028600","p":"52436.85000000","P":"0.00000000","F":"0.00000000","g":-1,"C":"","x":"TRADE","X":"FILLED","r":"NONE","i":5341783271,"l":"0.00028600","z":"0.00028600","L":"52436.85000000","n":"0.00000029","N":"BTC","T":1616633041555,"t":726946523,"I":11390206312,"w":false,"m":false,"M":true,"O":1616633041555,"Z":"14.99693910","Y":"14.99693910","Q":"0.00000000","W":1616633041555}}`) - err = e.wsHandleData(payload) + err = e.wsHandleData(t.Context(), payload) if err != nil { t.Fatal(err) } @@ -2562,7 +2567,7 @@ func TestWsOrderExecutionReport(t *testing.T) { func TestWsOutboundAccountPosition(t *testing.T) { t.Parallel() payload := []byte(`{"stream":"jTfvpakT2yT0hVIo5gYWVihZhdM2PrBgJUZ5PyfZ4EVpCkx4Uoxk5timcrQc","data":{"e":"outboundAccountPosition","E":1616628815745,"u":1616628815745,"B":[{"a":"BTC","f":"0.00225109","l":"0.00123000"},{"a":"BNB","f":"0.00000000","l":"0.00000000"},{"a":"USDT","f":"54.43390661","l":"0.00000000"}]}}`) - if err := e.wsHandleData(payload); err != nil { + if err := e.wsHandleData(t.Context(), payload); err != nil { t.Fatal(err) } } diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 4d98d888acb..9295f27c05c 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -97,7 +97,7 @@ func (e *Exchange) WsConnect() error { }) e.Websocket.Wg.Add(1) - go e.wsReadData() + go e.wsReadData(ctx) e.setupOrderbookManager(ctx) return nil @@ -133,24 +133,23 @@ func (e *Exchange) setupOrderbookManager(ctx context.Context) { func (e *Exchange) KeepAuthKeyAlive(ctx context.Context) { e.Websocket.Wg.Add(1) defer e.Websocket.Wg.Done() - ticks := time.NewTicker(time.Minute * 30) for { select { case <-e.Websocket.ShutdownC: - ticks.Stop() return - case <-ticks.C: - err := e.MaintainWsAuthStreamKey(ctx) - if err != nil { - e.Websocket.DataHandler <- err - log.Warnf(log.ExchangeSys, "%s - Unable to renew auth websocket token, may experience shutdown", e.Name) + case <-time.After(time.Minute * 30): + if err := e.MaintainWsAuthStreamKey(ctx); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } + log.Warnf(log.ExchangeSys, "%s %s: Unable to renew auth websocket token, may experience shutdown", e.Name, e.Websocket.Conn.GetURL()) } } } } // wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData() { +func (e *Exchange) wsReadData(ctx context.Context) { defer e.Websocket.Wg.Done() for { @@ -158,14 +157,15 @@ func (e *Exchange) wsReadData() { if resp.Raw == nil { return } - err := e.wsHandleData(resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } -func (e *Exchange) wsHandleData(respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { if id, err := jsonparser.GetString(respRaw, "id"); err == nil { if e.Websocket.Match.IncomingWithData(id, respRaw) { return nil @@ -193,8 +193,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err) } - e.Websocket.DataHandler <- data - return nil + return e.Websocket.DataHandler.Send(ctx, data) case "balanceUpdate": var data WsBalanceUpdateData err = json.Unmarshal(jsonData, &data) @@ -203,8 +202,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err) } - e.Websocket.DataHandler <- data - return nil + return e.Websocket.DataHandler.Send(ctx, data) case "executionReport": var data WsOrderUpdateData err = json.Unmarshal(jsonData, &data) @@ -232,11 +230,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { var orderStatus order.Status orderStatus, err = stringToOrderStatus(data.OrderStatus) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } clientOrderID := data.ClientOrderID if orderStatus == order.Cancelled { @@ -245,22 +239,14 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { var orderType order.Type orderType, err = order.StringToOrderType(data.OrderType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } var orderSide order.Side orderSide, err = order.StringToOrderSide(data.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: data.Price, Amount: data.Quantity, AverageExecutedPrice: avgPrice, @@ -280,8 +266,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { Date: data.OrderCreationTime.Time(), LastUpdated: data.TransactionTime.Time(), Pair: pair, - } - return nil + }) case "listStatus": var data WsListStatusData err = json.Unmarshal(jsonData, &data) @@ -290,8 +275,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err) } - e.Websocket.DataHandler <- data - return nil + return e.Websocket.DataHandler.Send(ctx, data) } } @@ -362,7 +346,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err.Error()) } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Open: t.OpenPrice.Float64(), Close: t.ClosePrice.Float64(), @@ -376,8 +360,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { LastUpdated: t.EventTime.Time(), AssetType: asset.Spot, Pair: pair, - } - return nil + }) case "kline_1m", "kline_3m", "kline_5m", "kline_15m", "kline_30m", "kline_1h", "kline_2h", "kline_4h", "kline_6h", "kline_8h", "kline_12h", "kline_1d", "kline_3d", "kline_1w", "kline_1M": var kline KlineStream @@ -387,7 +370,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err) } - e.Websocket.DataHandler <- websocket.KlineData{ + return e.Websocket.DataHandler.Send(ctx, websocket.KlineData{ Timestamp: kline.EventTime.Time(), Pair: pair, AssetType: asset.Spot, @@ -400,8 +383,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { HighPrice: kline.Kline.HighPrice.Float64(), LowPrice: kline.Kline.LowPrice.Float64(), Volume: kline.Kline.Volume.Float64(), - } - return nil + }) case "depth": var depth WebsocketDepthStream err = json.Unmarshal(jsonData, &depth) @@ -586,8 +568,6 @@ func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription. if err != nil { err = fmt.Errorf("%w; Channels: %s", err, strings.Join(subs.QualifiedChannels(), ", ")) - e.Websocket.DataHandler <- err - if op == wsSubscribeMethod { if err2 := e.Websocket.RemoveSubscriptions(e.Websocket.Conn, subs...); err2 != nil { err = common.AppendError(err, err2) diff --git a/exchanges/binanceus/binanceus_test.go b/exchanges/binanceus/binanceus_test.go index b92f721198f..e80c0839073 100644 --- a/exchanges/binanceus/binanceus_test.go +++ b/exchanges/binanceus/binanceus_test.go @@ -1207,7 +1207,7 @@ func TestWebsocketSubscriptionHandling(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, e) rawData := []byte(subscriptionRequestString) - err := e.wsHandleData(rawData) + err := e.wsHandleData(t.Context(), rawData) if err != nil { t.Error("Binanceus wsHandleData() error", err) } @@ -1221,7 +1221,7 @@ func TestWebsocketUnsubscriptionHandling(t *testing.T) { ], "id": 312 }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -1265,7 +1265,7 @@ var ticker24hourChangeStream = `{ func TestWebsocketTickerUpdate(t *testing.T) { t.Parallel() - if err := e.wsHandleData([]byte(ticker24hourChangeStream)); err != nil { + if err := e.wsHandleData(t.Context(), []byte(ticker24hourChangeStream)); err != nil { t.Error("Binanceus wsHandleData() for Ticker 24h Change Stream", err) } } @@ -1300,7 +1300,7 @@ func TestWebsocketKlineUpdate(t *testing.T) { } } }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error("Binanceus wsHandleData() btcusdt@kline_1m stream data conversion ", err) } } @@ -1320,7 +1320,7 @@ func TestWebsocketStreamTradeUpdate(t *testing.T) { "m": true, "M": true }}`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error("Binanceus wsHandleData() error", err) } } @@ -1376,7 +1376,7 @@ func TestWebsocketOrderBookDepthDiffStream(t *testing.T) { if err := e.SeedLocalCacheWithBook(p, &book); err != nil { t.Fatal(err) } - if err := e.wsHandleData(update1); err != nil { + if err := e.wsHandleData(t.Context(), update1); err != nil { t.Fatal(err) } e.obm.state[currency.BTC][currency.USDT][asset.Spot].fetchingBook = false @@ -1409,7 +1409,7 @@ func TestWebsocketOrderBookDepthDiffStream(t *testing.T) { ] } }`) - if err = e.wsHandleData(update2); err != nil { + if err = e.wsHandleData(t.Context(), update2); err != nil { t.Error("Binanceus wshandlerData error", err) } ob, err = e.Websocket.Orderbook.GetOrderbook(p, asset.Spot) @@ -1451,7 +1451,7 @@ func TestWebsocketPartialOrderBookDepthStream(t *testing.T) { ] }}`) var err error - if err = e.wsHandleData(update1); err != nil { + if err = e.wsHandleData(t.Context(), update1); err != nil { t.Error("Binanceus Partial Order Book Depth Sream error", err) } update2 := []byte(`{ @@ -1472,7 +1472,7 @@ func TestWebsocketPartialOrderBookDepthStream(t *testing.T) { ] } }`) - if err = e.wsHandleData(update2); err != nil { + if err = e.wsHandleData(t.Context(), update2); err != nil { t.Error("Binanceus Partial Order Book Depth Sream error", err) } } @@ -1491,7 +1491,7 @@ func TestWebsocketBookTicker(t *testing.T) { "A":"40.66000000" } }`) - if err := e.wsHandleData(bookTickerJSON); err != nil { + if err := e.wsHandleData(t.Context(), bookTickerJSON); err != nil { t.Error("Binanceus Book Ticker error", err) } bookTickerForAllSymbols := []byte(` @@ -1506,7 +1506,7 @@ func TestWebsocketBookTicker(t *testing.T) { "A":"40.66000000" } }`) - if err := e.wsHandleData(bookTickerForAllSymbols); err != nil { + if err := e.wsHandleData(t.Context(), bookTickerForAllSymbols); err != nil { t.Error("Binanceus Web socket Book ticker for all symbols error", err) } } @@ -1530,7 +1530,7 @@ func TestWebsocketAggTrade(t *testing.T) { "M": true } }`) - if err := e.wsHandleData(aggTradejson); err != nil { + if err := e.wsHandleData(t.Context(), aggTradejson); err != nil { t.Error("Binanceus Aggregated Trade Order Json() error", err) } } @@ -1548,7 +1548,7 @@ var balanceUpdateInputJSON = ` func TestWebsocketBalanceUpdate(t *testing.T) { t.Parallel() thejson := []byte(balanceUpdateInputJSON) - if err := e.wsHandleData(thejson); err != nil { + if err := e.wsHandleData(t.Context(), thejson); err != nil { t.Error(err) } } @@ -1584,7 +1584,7 @@ var listStatusUserDataStreamPayload = ` func TestWebsocketListStatus(t *testing.T) { t.Parallel() - if err := e.wsHandleData([]byte(listStatusUserDataStreamPayload)); err != nil { + if err := e.wsHandleData(t.Context(), []byte(listStatusUserDataStreamPayload)); err != nil { t.Error(err) } } @@ -1679,15 +1679,15 @@ func TestWebsocketOrderExecutionReport(t *testing.T) { LastUpdated: time.UnixMilli(1616627567900), Pair: currency.NewBTCUSDT(), } - for len(e.Websocket.DataHandler) > 0 { - <-e.Websocket.DataHandler + for ch := e.Websocket.DataHandler.C; len(ch) > 0; { + <-ch } - err := e.wsHandleData(payload) + err := e.wsHandleData(t.Context(), payload) if err != nil { t.Fatal(err) } - res := <-e.Websocket.DataHandler - switch r := res.(type) { + res := <-e.Websocket.DataHandler.C + switch r := res.Data.(type) { case *order.Detail: if !reflects.DeepEqual(expectedResult, *r) { t.Errorf("Binanceus Results do not match:\nexpected: %v\nreceived: %v", expectedResult, *r) @@ -1696,7 +1696,7 @@ func TestWebsocketOrderExecutionReport(t *testing.T) { t.Fatalf("Binanceus expected type order.Detail, found %T", res) } payload = []byte(`{"stream":"jTfvpakT2yT0hVIo5gYWVihZhdM2PrBgJUZ5PyfZ4EVpCkx4Uoxk5timcrQc","data":{"e":"executionReport","E":1616633041556,"s":"BTCUSDT","c":"YeULctvPAnHj5HXCQo9Mob","S":"BUY","o":"LIMIT","f":"GTC","q":"0.00028600","p":"52436.85000000","P":"0.00000000","F":"0.00000000","g":-1,"C":"","x":"TRADE","X":"FILLED","r":"NONE","i":5341783271,"l":"0.00028600","z":"0.00028600","L":"52436.85000000","n":"0.00000029","N":"BTC","T":1616633041555,"t":726946523,"I":11390206312,"w":false,"m":false,"M":true,"O":1616633041555,"Z":"14.99693910","Y":"14.99693910","Q":"0.00000000"}}`) - err = e.wsHandleData(payload) + err = e.wsHandleData(t.Context(), payload) if err != nil { t.Fatal("Binanceus OrderExecutionReport json conversion error", err) } @@ -1705,7 +1705,7 @@ func TestWebsocketOrderExecutionReport(t *testing.T) { func TestWebsocketOutboundAccountPosition(t *testing.T) { t.Parallel() payload := []byte(`{"stream":"jTfvpakT2yT0hVIo5gYWVihZhdM2PrBgJUZ5PyfZ4EVpCkx4Uoxk5timcrQc","data":{"e":"outboundAccountPosition","E":1616628815745,"u":1616628815745,"B":[{"a":"BTC","f":"0.00225109","l":"0.00123000"},{"a":"BNB","f":"0.00000000","l":"0.00000000"},{"a":"USDT","f":"54.43390661","l":"0.00000000"}]}}`) - if err := e.wsHandleData(payload); err != nil { + if err := e.wsHandleData(t.Context(), payload); err != nil { t.Fatal("Binanceus testing \"outboundAccountPosition\" data conversion error", err) } } diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index f050d1acc56..b52426fe603 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -90,7 +90,7 @@ func (e *Exchange) WsConnect() error { }) e.Websocket.Wg.Add(1) - go e.wsReadData() + go e.wsReadData(ctx) e.setupOrderbookManager(ctx) return nil @@ -100,33 +100,30 @@ func (e *Exchange) WsConnect() error { // keep the WS auth key active func (e *Exchange) KeepAuthKeyAlive(ctx context.Context) { defer e.Websocket.Wg.Done() - // ClosUserDataStream closes the User data stream and remove the listen key when closing the websocket. + // CloseUserDataStream closes the User data stream and remove the listen key when closing the websocket defer func() { - er := e.CloseUserDataStream(ctx) - if er != nil { - log.Errorf(log.WebsocketMgr, "%s closing user data stream error %v", - e.Name, er) + if err := e.CloseUserDataStream(ctx); err != nil { + log.Errorf(log.WebsocketMgr, "%s closing user data stream error %v", e.Name, err) } }() - // Looping in 30 Minutes and updating the listenKey - ticks := time.NewTicker(time.Minute * 30) + for { select { case <-e.Websocket.ShutdownC: - ticks.Stop() return - case <-ticks.C: - err := e.MaintainWsAuthStreamKey(ctx) - if err != nil { - e.Websocket.DataHandler <- err - log.Warnf(log.ExchangeSys, "%s - Unable to renew auth websocket token, may experience shutdown", e.Name) + case <-time.After(time.Minute * 30): + if err := e.MaintainWsAuthStreamKey(ctx); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } + log.Warnf(log.ExchangeSys, "%s %s: Unable to renew auth websocket token, may experience shutdown", e.Name, e.Websocket.Conn.GetURL()) } } } } // wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData() { +func (e *Exchange) wsReadData(ctx context.Context) { defer e.Websocket.Wg.Done() for { @@ -134,9 +131,10 @@ func (e *Exchange) wsReadData() { if resp.Raw == nil { return } - err := e.wsHandleData(resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } @@ -162,7 +160,7 @@ func stringToOrderStatus(status string) (order.Status, error) { } } -func (e *Exchange) wsHandleData(respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { var multiStreamData map[string]any err := json.Unmarshal(respRaw, &multiStreamData) if err != nil { @@ -194,8 +192,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err) } - e.Websocket.DataHandler <- data - return nil + return e.Websocket.DataHandler.Send(ctx, data) case "balanceUpdate": var data wsBalanceUpdate err := json.Unmarshal(respRaw, &data) @@ -204,8 +201,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err) } - e.Websocket.DataHandler <- data - return nil + return e.Websocket.DataHandler.Send(ctx, data) case "executionReport": var data wsOrderUpdate err := json.Unmarshal(respRaw, &data) @@ -230,11 +226,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { orderID := strconv.FormatInt(data.Data.OrderID, 10) orderStatus, err := stringToOrderStatus(data.Data.OrderStatus) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } clientOrderID := data.Data.ClientOrderID if orderStatus == order.Cancelled { @@ -242,21 +234,13 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { } orderType, err := order.StringToOrderType(data.Data.OrderType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } orderSide, err := order.StringToOrderSide(data.Data.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: data.Data.Price, Amount: data.Data.Quantity, AverageExecutedPrice: averagePrice, @@ -276,8 +260,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { Date: data.Data.OrderCreationTime.Time(), LastUpdated: data.Data.TransactionTime.Time(), Pair: pair, - } - return nil + }) case "listStatus": var data WsListStatus err := json.Unmarshal(respRaw, &data) @@ -286,8 +269,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err) } - e.Websocket.DataHandler <- data - return nil + return e.Websocket.DataHandler.Send(ctx, data) } } } @@ -356,7 +338,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Open: t.OpenPrice, Close: t.ClosePrice, @@ -370,8 +352,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { LastUpdated: t.EventTime.Time(), AssetType: asset.Spot, Pair: pair, - } - return nil + }) case "kline_1m", "kline_3m", "kline_5m", "kline_15m", "kline_30m", "kline_1h", "kline_2h", "kline_4h", "kline_6h", "kline_8h", "kline_12h", "kline_1d", "kline_3d", "kline_1w", "kline_1M": var kline KlineStream @@ -387,7 +368,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return err } - e.Websocket.DataHandler <- websocket.KlineData{ + return e.Websocket.DataHandler.Send(ctx, websocket.KlineData{ Timestamp: kline.EventTime.Time(), Pair: pair, AssetType: asset.Spot, @@ -400,8 +381,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { HighPrice: kline.Kline.HighPrice, LowPrice: kline.Kline.LowPrice, Volume: kline.Kline.Volume, - } - return nil + }) case "depth": var depth WebsocketDepthStream err := json.Unmarshal(rawData, &depth) @@ -428,8 +408,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { e.Name, err) } - e.Websocket.DataHandler <- depth - return nil + return e.Websocket.DataHandler.Send(ctx, &depth) case "bookTicker": var bo OrderBookTickerStream err := json.Unmarshal(rawData, &bo) @@ -441,20 +420,18 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return err } bo.Symbol = pair - e.Websocket.DataHandler <- &bo - return nil + return e.Websocket.DataHandler.Send(ctx, &bo) case "aggTrade": var agg WebsocketAggregateTradeStream err := json.Unmarshal(rawData, &agg) if err != nil { return fmt.Errorf("%v - Could not convert to aggTrade structure %s ", err, e.Name) } - e.Websocket.DataHandler <- agg - return nil + return e.Websocket.DataHandler.Send(ctx, &agg) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), - } + }) } } } else if wsStream == "!bookTicker" { @@ -482,8 +459,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return err } bt.Symbol = pair - e.Websocket.DataHandler <- &bt - return nil + return e.Websocket.DataHandler.Send(ctx, &bt) } } } diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index 90f163f1d73..e3921bf5f6e 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1120,8 +1120,8 @@ func TestWSAuth(t *testing.T) { var resp map[string]any catcher := func() (ok bool) { select { - case v := <-e.Websocket.ToRoutine: - resp, ok = v.(map[string]any) + case v := <-e.Websocket.DataHandler.C: + resp, ok = v.Data.(map[string]any) default: } return @@ -1189,8 +1189,8 @@ func TestWSSubscribe(t *testing.T) { err := e.Subscribe(subscription.List{{Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewBTCUSD()}, Asset: asset.Spot}}) require.NoError(t, err, "Subrcribe must not error") catcher := func() (ok bool) { - i := <-e.Websocket.ToRoutine - _, ok = i.(*ticker.Price) + i := <-e.Websocket.DataHandler.C + _, ok = i.Data.(*ticker.Price) return } assert.Eventually(t, catcher, sharedtestvalues.WebsocketResponseDefaultTimeout, time.Millisecond*10, "Ticker response should arrive") @@ -1202,13 +1202,6 @@ func TestWSSubscribe(t *testing.T) { err = e.Subscribe(subscription.List{{Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewBTCUSD()}, Asset: asset.Spot}}) require.ErrorContains(t, err, "subscribe: dup (code: 10301)", "Duplicate subscription must error correctly") - assert.EventuallyWithT(t, func(t *assert.CollectT) { - i := <-e.Websocket.ToRoutine - e, ok := i.(error) - require.True(t, ok, "must find an error") - assert.ErrorContains(t, e, "subscribe: dup (code: 10301)", "error should be correct") - }, sharedtestvalues.WebsocketResponseDefaultTimeout, time.Millisecond*10, "error response should go to ToRoutine") - subs, err = e.GetSubscriptions() require.NoError(t, err, "GetSubscriptions must not error") require.Len(t, subs, 1, "We must only have one subscription after an error attempt") @@ -1379,7 +1372,7 @@ func TestWSAllTrades(t *testing.T) { err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.AllTradesChannel, Key: 18788}) require.NoError(t, err, "AddSubscriptions must not error") testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) - close(e.Websocket.DataHandler) + e.Websocket.DataHandler.Close() expJSON := []string{ `{"TID":"412685577","AssetType":"spot","Side":"BUY","Price":176.3,"Amount":11.1998,"Timestamp":"2020-01-29T03:27:24.802Z"}`, `{"TID":"412685578","AssetType":"spot","Side":"SELL","Price":176.29952759,"Amount":5,"Timestamp":"2020-01-29T03:28:04.802Z"}`, @@ -1389,11 +1382,11 @@ func TestWSAllTrades(t *testing.T) { `{"TID":"5690221203","AssetType":"marginFunding","Side":"BUY","Price":102550,"Amount":0.00991467,"Timestamp":"2024-12-15T04:30:18.019Z"}`, `{"TID":"5690221204","AssetType":"marginFunding","Side":"SELL","Price":102540,"Amount":0.01925285,"Timestamp":"2024-12-15T04:30:18.094Z"}`, } - require.Len(t, e.Websocket.DataHandler, len(expJSON), "Must see correct number of trades") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + require.Len(t, e.Websocket.DataHandler.C, len(expJSON), "Must see correct number of trades") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case trade.Data: - i := 6 - len(e.Websocket.DataHandler) + i := 6 - len(e.Websocket.DataHandler.C) exp := trade.Data{ Exchange: e.Name, CurrencyPair: btcusdPair, diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index 37d28196e6f..d2ae8bc3fe2 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -101,8 +101,6 @@ var defaultSubscriptions = subscription.List{ {Enabled: true, Channel: subscription.OrderbookChannel, Asset: asset.All, Levels: 100, Params: map[string]any{"prec": "R0"}}, } -var comms = make(chan websocket.Response) - type checksum struct { Token uint32 Sequence int64 @@ -136,7 +134,7 @@ func (e *Exchange) WsConnect() error { } e.Websocket.Wg.Add(1) - go e.wsReadData(e.Websocket.Conn) + go e.wsReadData(ctx, e.Websocket.Conn) if e.Websocket.CanUseAuthenticatedEndpoints() { err = e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{}) if err != nil { @@ -147,7 +145,7 @@ func (e *Exchange) WsConnect() error { e.Websocket.SetCanUseAuthenticatedEndpoints(false) } e.Websocket.Wg.Add(1) - go e.wsReadData(e.Websocket.AuthConn) + go e.wsReadData(ctx, e.Websocket.AuthConn) err = e.WsSendAuth(ctx) if err != nil { log.Errorf(log.ExchangeSys, @@ -159,61 +157,33 @@ func (e *Exchange) WsConnect() error { } e.Websocket.Wg.Add(1) - go e.WsDataHandler(ctx) return e.ConfigureWS(ctx) } // wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData(ws websocket.Connection) { +func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) { defer e.Websocket.Wg.Done() for { resp := ws.ReadMessage() if resp.Raw == nil { return } - comms <- resp - } -} - -// WsDataHandler handles data from wsReadData -func (e *Exchange) WsDataHandler(ctx context.Context) { - defer e.Websocket.Wg.Done() - for { - select { - case <-e.Websocket.ShutdownC: - select { - case resp := <-comms: - err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - select { - case e.Websocket.DataHandler <- err: - default: - log.Errorf(log.WebsocketMgr, "%s websocket handle data error: %v", e.Name, err) - } - } - default: - } - return - case resp := <-comms: - if resp.Type != gws.TextMessage { - continue - } - err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) } } } } -func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { var result any if err := json.Unmarshal(respRaw, &result); err != nil { return err } switch d := result.(type) { case map[string]any: - return e.handleWSEvent(respRaw) + return e.handleWSEvent(ctx, respRaw) case []any: chanIDFloat, ok := d[0].(float64) if !ok { @@ -225,7 +195,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if chanID != 0 { if s := e.Websocket.GetSubscription(chanID); s != nil { - return e.handleWSChannelUpdate(s, respRaw, eventType, d) + return e.handleWSChannelUpdate(ctx, s, respRaw, eventType, d) } if e.Verbose { log.Warnf(log.ExchangeSys, "%s %s; dropped WS message: %s", e.Name, subscription.ErrNotFound, respRaw) @@ -244,27 +214,29 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { case wsHeartbeat, pong: return nil case wsNotification: - return e.handleWSNotification(d, respRaw) + return e.handleWSNotification(ctx, d, respRaw) case wsOrderSnapshot: if snapBundle, ok := d[2].([]any); ok && len(snapBundle) > 0 { if _, ok := snapBundle[0].([]any); ok { for i := range snapBundle { if positionData, ok := snapBundle[i].([]any); ok { - e.wsHandleOrder(positionData) + if err := e.wsHandleOrder(ctx, positionData); err != nil { + return err + } } } } } case wsOrderCancel, wsOrderNew, wsOrderUpdate: if oData, ok := d[2].([]any); ok && len(oData) > 0 { - e.wsHandleOrder(oData) + return e.wsHandleOrder(ctx, oData) } case wsPositionSnapshot: - return e.handleWSPositionSnapshot(d) + return e.handleWSPositionSnapshot(ctx, d) case wsPositionNew, wsPositionUpdate, wsPositionClose: - return e.handleWSPositionUpdate(d) + return e.handleWSPositionUpdate(ctx, d) case wsTradeExecuted, wsTradeUpdated: - return e.handleWSMyTradeUpdate(d, eventType) + return e.handleWSMyTradeUpdate(ctx, d, eventType) case wsFundingOfferSnapshot: if snapBundle, ok := d[2].([]any); ok && len(snapBundle) > 0 { if _, ok := snapBundle[0].([]any); ok { @@ -280,7 +252,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { } snapshot[i] = offer } - e.Websocket.DataHandler <- snapshot + return e.Websocket.DataHandler.Send(ctx, snapshot) } } case wsFundingOfferNew, wsFundingOfferUpdate, wsFundingOfferCancel: @@ -289,7 +261,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- offer + return e.Websocket.DataHandler.Send(ctx, offer) } case wsFundingCreditSnapshot: if snapBundle, ok := d[2].([]any); ok && len(snapBundle) > 0 { @@ -306,7 +278,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { } snapshot[i] = fundingCredit } - e.Websocket.DataHandler <- snapshot + return e.Websocket.DataHandler.Send(ctx, snapshot) } } case wsFundingCreditNew, wsFundingCreditUpdate, wsFundingCreditCancel: @@ -315,7 +287,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- fundingCredit + return e.Websocket.DataHandler.Send(ctx, fundingCredit) } case wsFundingLoanSnapshot: if snapBundle, ok := d[2].([]any); ok && len(snapBundle) > 0 { @@ -332,7 +304,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { } snapshot[i] = fundingLoanSnapshot } - e.Websocket.DataHandler <- snapshot + return e.Websocket.DataHandler.Send(ctx, snapshot) } } case wsFundingLoanNew, wsFundingLoanUpdate, wsFundingLoanCancel: @@ -341,7 +313,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- fundingData + return e.Websocket.DataHandler.Send(ctx, fundingData) } case wsWalletSnapshot: if snapBundle, ok := d[2].([]any); ok && len(snapBundle) > 0 { @@ -372,7 +344,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { } snapshot[i] = wallet } - e.Websocket.DataHandler <- snapshot + return e.Websocket.DataHandler.Send(ctx, snapshot) } } case wsWalletUpdate: @@ -395,7 +367,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { return errors.New("unable to type assert wallet snapshot balance available") } } - e.Websocket.DataHandler <- wallet + return e.Websocket.DataHandler.Send(ctx, wallet) } case wsBalanceUpdate: if data, ok := d[2].([]any); ok && len(data) > 0 { @@ -406,7 +378,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if balance.NetAssetsUnderManagement, ok = data[1].(float64); !ok { return errors.New("unable to type assert balance net assets under management") } - e.Websocket.DataHandler <- balance + return e.Websocket.DataHandler.Send(ctx, balance) } case wsMarginInfoUpdate: if data, ok := d[2].([]any); ok && len(data) > 0 { @@ -431,7 +403,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if marginInfoBase.MarginRequired, ok = baseData[4].(float64); !ok { return errors.New("unable to type assert margin info required") } - e.Websocket.DataHandler <- marginInfoBase + return e.Websocket.DataHandler.Send(ctx, marginInfoBase) } } case wsFundingInfoUpdate: @@ -457,7 +429,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if fundingInfo.DurationLend, ok = symbolData[3].(float64); !ok { return errors.New("unable to type assert funding info update duration lend") } - e.Websocket.DataHandler <- fundingInfo + return e.Websocket.DataHandler.Send(ctx, fundingInfo) } } case wsFundingTradeExecuted, wsFundingTradeUpdated: @@ -493,19 +465,18 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { } wsFundingTrade.Period = int64(period) wsFundingTrade.Maker = data[7] != nil - e.Websocket.DataHandler <- wsFundingTrade + return e.Websocket.DataHandler.Send(ctx, wsFundingTrade) } default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), - } - return nil + }) } } return nil } -func (e *Exchange) handleWSEvent(respRaw []byte) error { +func (e *Exchange) handleWSEvent(ctx context.Context, respRaw []byte) error { event, err := jsonparser.GetUnsafeString(respRaw, "event") if err != nil { return fmt.Errorf("%w 'event': %w from message: %s", common.ErrParsingWSField, err, respRaw) @@ -547,14 +518,13 @@ func (e *Exchange) handleWSEvent(respRaw []byte) error { return fmt.Errorf("unable to Unmarshal auth resp; Error: %w Msg: %v", err, respRaw) } // TODO - Send a better value down the channel - e.Websocket.DataHandler <- glob - } else { - errCode, err := jsonparser.GetInt(respRaw, "code") - if err != nil { - log.Errorf(log.ExchangeSys, "%s %s 'code': %s from message: %s", e.Name, common.ErrParsingWSField, err, respRaw) - } - return fmt.Errorf("WS auth subscription error; Status: %s Error Code: %d", status, errCode) + return e.Websocket.DataHandler.Send(ctx, glob) + } + errCode, err := jsonparser.GetInt(respRaw, "code") + if err != nil { + log.Errorf(log.ExchangeSys, "%s %s 'code': %s from message: %s", e.Name, common.ErrParsingWSField, err, respRaw) } + return fmt.Errorf("WS auth subscription error; Status: %s Error Code: %d", status, errCode) case wsEventInfo: // Nothing to do with info for now. // version or platform.status might be useful in the future. @@ -608,7 +578,7 @@ func (e *Exchange) handleWSSubscribed(respRaw []byte) error { return e.Websocket.Match.RequireMatchWithData("subscribe:"+subID, respRaw) } -func (e *Exchange) handleWSChannelUpdate(s *subscription.Subscription, respRaw []byte, eventType string, d []any) error { +func (e *Exchange) handleWSChannelUpdate(ctx context.Context, s *subscription.Subscription, respRaw []byte, eventType string, d []any) error { if s == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -626,13 +596,13 @@ func (e *Exchange) handleWSChannelUpdate(s *subscription.Subscription, respRaw [ switch s.Channel { case subscription.OrderbookChannel: - return e.handleWSBookUpdate(s, d) + return e.handleWSBookUpdate(ctx, s, d) case subscription.CandlesChannel: - return e.handleWSAllCandleUpdates(s, respRaw) + return e.handleWSAllCandleUpdates(ctx, s, respRaw) case subscription.TickerChannel: - return e.handleWSTickerUpdate(s, d) + return e.handleWSTickerUpdate(ctx, s, d) case subscription.AllTradesChannel: - return e.handleWSAllTrades(s, respRaw) + return e.handleWSAllTrades(ctx, s, respRaw) } return fmt.Errorf("%s unhandled channel update: %s", e.Name, s.Channel) @@ -672,7 +642,7 @@ func (e *Exchange) handleWSChecksum(c *subscription.Subscription, d []any) error return nil } -func (e *Exchange) handleWSBookUpdate(c *subscription.Subscription, d []any) error { +func (e *Exchange) handleWSBookUpdate(ctx context.Context, c *subscription.Subscription, d []any) error { if c == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -768,7 +738,7 @@ func (e *Exchange) handleWSBookUpdate(c *subscription.Subscription, d []any) err }) } - if err := e.WsUpdateOrderbook(c, c.Pairs[0], c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil { + if err := e.WsUpdateOrderbook(ctx, c, c.Pairs[0], c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil { return fmt.Errorf("updating orderbook error: %s", err) } @@ -777,7 +747,7 @@ func (e *Exchange) handleWSBookUpdate(c *subscription.Subscription, d []any) err return nil } -func (e *Exchange) handleWSAllCandleUpdates(c *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) handleWSAllCandleUpdates(ctx context.Context, c *subscription.Subscription, respRaw []byte) error { if c == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -818,11 +788,10 @@ func (e *Exchange) handleWSAllCandleUpdates(c *subscription.Subscription, respRa Volume: wsCandles[i].Volume.Float64(), } } - e.Websocket.DataHandler <- klines - return nil + return e.Websocket.DataHandler.Send(ctx, klines) } -func (e *Exchange) handleWSTickerUpdate(c *subscription.Subscription, d []any) error { +func (e *Exchange) handleWSTickerUpdate(ctx context.Context, c *subscription.Subscription, d []any) error { if c == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -897,11 +866,10 @@ func (e *Exchange) handleWSTickerUpdate(c *subscription.Subscription, d []any) e return errors.New("unable to type assert ticker flash return rate") } } - e.Websocket.DataHandler <- t - return nil + return e.Websocket.DataHandler.Send(ctx, t) } -func (e *Exchange) handleWSAllTrades(s *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) handleWSAllTrades(ctx context.Context, s *subscription.Subscription, respRaw []byte) error { feedEnabled := e.IsTradeFeedEnabled() if !feedEnabled && !e.IsSaveTradeDataEnabled() { return nil @@ -948,7 +916,9 @@ func (e *Exchange) handleWSAllTrades(s *subscription.Subscription, respRaw []byt t.Price = w.Rate } if feedEnabled { - e.Websocket.DataHandler <- t + if err := e.Websocket.DataHandler.Send(ctx, t); err != nil { + return err + } } } if e.IsSaveTradeDataEnabled() { @@ -971,7 +941,7 @@ func (e *Exchange) handleWSPublicTradeUpdate(respRaw []byte) (*Trade, error) { return t, json.Unmarshal(v, t) } -func (e *Exchange) handleWSNotification(d []any, respRaw []byte) error { +func (e *Exchange) handleWSNotification(ctx context.Context, d []any, respRaw []byte) error { notification, ok := d[2].([]any) if !ok { return errors.New("unable to type assert notification data") @@ -994,7 +964,7 @@ func (e *Exchange) handleWSNotification(d []any, respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- offer + return e.Websocket.DataHandler.Send(ctx, offer) } } case strings.Contains(channelName, wsOrderNewRequest): @@ -1005,7 +975,7 @@ func (e *Exchange) handleWSNotification(d []any, respRaw []byte) error { if e.Websocket.Match.IncomingWithData(int64(cid), respRaw) { return nil } - e.wsHandleOrder(data) + return e.wsHandleOrder(ctx, data) } } case strings.Contains(channelName, wsOrderUpdateRequest), @@ -1017,7 +987,7 @@ func (e *Exchange) handleWSNotification(d []any, respRaw []byte) error { if e.Websocket.Match.IncomingWithData(int64(id), respRaw) { return nil } - e.wsHandleOrder(data) + return e.wsHandleOrder(ctx, data) } } default: @@ -1042,7 +1012,7 @@ func (e *Exchange) handleWSNotification(d []any, respRaw []byte) error { return nil } -func (e *Exchange) handleWSPositionSnapshot(d []any) error { +func (e *Exchange) handleWSPositionSnapshot(ctx context.Context, d []any) error { snapBundle, ok := d[2].([]any) if !ok { return common.GetTypeAssertError("[]any", d[2], "positionSnapshotBundle") @@ -1091,11 +1061,10 @@ func (e *Exchange) handleWSPositionSnapshot(d []any) error { } snapshot[i] = position } - e.Websocket.DataHandler <- snapshot - return nil + return e.Websocket.DataHandler.Send(ctx, snapshot) } -func (e *Exchange) handleWSPositionUpdate(d []any) error { +func (e *Exchange) handleWSPositionUpdate(ctx context.Context, d []any) error { positionData, ok := d[2].([]any) if !ok { return common.GetTypeAssertError("[]any", d[2], "positionUpdate") @@ -1136,11 +1105,10 @@ func (e *Exchange) handleWSPositionUpdate(d []any) error { if position.Leverage, ok = positionData[9].(float64); !ok { return errors.New("unable to type assert position leverage") } - e.Websocket.DataHandler <- position - return nil + return e.Websocket.DataHandler.Send(ctx, position) } -func (e *Exchange) handleWSMyTradeUpdate(d []any, eventType string) error { +func (e *Exchange) handleWSMyTradeUpdate(ctx context.Context, d []any, eventType string) error { tradeData, ok := d[2].([]any) if !ok { return common.GetTypeAssertError("[]any", d[2], "tradeUpdate") @@ -1192,8 +1160,7 @@ func (e *Exchange) handleWSMyTradeUpdate(d []any, eventType string) error { return errors.New("unable to type assert trade fee currency") } } - e.Websocket.DataHandler <- tData - return nil + return e.Websocket.DataHandler.Send(ctx, tData) } func wsHandleFundingOffer(data []any, includeRateReal bool) (*WsFundingOffer, error) { @@ -1409,7 +1376,7 @@ func wsHandleFundingCreditLoanData(data []any, includePositionPair bool) (*WsCre return &credit, nil } -func (e *Exchange) wsHandleOrder(data []any) { +func (e *Exchange) wsHandleOrder(ctx context.Context, data []any) error { var od order.Detail var err error od.Exchange = e.Name @@ -1452,8 +1419,7 @@ func (e *Exchange) wsHandleOrder(data []any) { if p, ok := data[3].(string); ok { od.Pair, od.AssetType, err = e.GetRequestFormattedPairAndAssetType(p[1:]) if err != nil { - e.Websocket.DataHandler <- err - return + return err } } } @@ -1461,11 +1427,7 @@ func (e *Exchange) wsHandleOrder(data []any) { if ordType, ok := data[8].(string); ok { oType, err := order.StringToOrderType(ordType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: od.OrderID, - Err: err, - } + return err } od.Type = oType } @@ -1475,16 +1437,12 @@ func (e *Exchange) wsHandleOrder(data []any) { statusParts := strings.Split(combinedStatus, " @ ") oStatus, err := order.StringToOrderStatus(statusParts[0]) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: od.OrderID, - Err: err, - } + return err } od.Status = oStatus } } - e.Websocket.DataHandler <- &od + return e.Websocket.DataHandler.Send(ctx, &od) } // WsInsertSnapshot add the initial orderbook snapshot when subscribed to a channel @@ -1531,7 +1489,7 @@ func (e *Exchange) WsInsertSnapshot(p currency.Pair, assetType asset.Item, books // WsUpdateOrderbook updates the orderbook list, removing and adding to the // orderbook sides -func (e *Exchange) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pair, assetType asset.Item, book []WebsocketBook, sequenceNo int64, fundingRate bool) error { +func (e *Exchange) WsUpdateOrderbook(ctx context.Context, c *subscription.Subscription, p currency.Pair, assetType asset.Item, book []WebsocketBook, sequenceNo int64, fundingRate bool) error { if c == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -1620,7 +1578,7 @@ func (e *Exchange) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pa if err = validateCRC32(ob, checkme.Token); err != nil { log.Errorf(log.WebsocketMgr, "%s websocket orderbook update error, will resubscribe orderbook: %v", e.Name, err) - if e2 := e.resubOrderbook(c); e2 != nil { + if e2 := e.resubOrderbook(ctx, c); e2 != nil { log.Errorf(log.WebsocketMgr, "%s error resubscribing orderbook: %v", e.Name, e2) } return err @@ -1633,7 +1591,7 @@ func (e *Exchange) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pa // resubOrderbook resubscribes the orderbook after a consistency error, probably a failed checksum, // which forces a fresh snapshot. If we don't do this the orderbook will keep erroring and drifting. // Flushing the orderbook happens immediately, but the ReSub itself is a go routine to avoid blocking the WS data channel -func (e *Exchange) resubOrderbook(c *subscription.Subscription) error { +func (e *Exchange) resubOrderbook(ctx context.Context, c *subscription.Subscription) error { if c == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -1647,7 +1605,7 @@ func (e *Exchange) resubOrderbook(c *subscription.Subscription) error { // Resub will block so we have to do this in a goro go func() { - if err := e.Websocket.ResubscribeToChannel(e.Websocket.Conn, c); err != nil { + if err := e.Websocket.ResubscribeToChannel(ctx, e.Websocket.Conn, c); err != nil { log.Errorf(log.ExchangeSys, "%s error resubscribing orderbook: %v", e.Name, err) } }() @@ -1737,9 +1695,7 @@ func (e *Exchange) subscribeToChan(ctx context.Context, subs subscription.List) } if err = e.getErrResp(respRaw); err != nil { - wErr := fmt.Errorf("%w: Channel: %s Pair: %s", err, s.Channel, s.Pairs) - e.Websocket.DataHandler <- wErr - return wErr + return fmt.Errorf("%w: Channel: %s Pair: %s", err, s.Channel, s.Pairs) } return nil @@ -1767,9 +1723,7 @@ func (e *Exchange) unsubscribeFromChan(ctx context.Context, subs subscription.Li } if err := e.getErrResp(respRaw); err != nil { - wErr := fmt.Errorf("%w: ChanId: %v", err, chanID) - e.Websocket.DataHandler <- wErr - return wErr + return fmt.Errorf("%w: ChanId: %v", err, chanID) } return e.Websocket.RemoveSubscriptions(e.Websocket.Conn, s) diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 8ac317075a4..b696b6efae8 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -20,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/log" ) const ( @@ -51,14 +52,14 @@ func (e *Exchange) WsConnect() error { } e.Websocket.Wg.Add(1) - go e.wsReadData() + go e.wsReadData(ctx) e.setupOrderbookManager(ctx) return nil } // wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData() { +func (e *Exchange) wsReadData(ctx context.Context) { defer e.Websocket.Wg.Done() for { @@ -70,15 +71,16 @@ func (e *Exchange) wsReadData() { if resp.Raw == nil { return } - err := e.wsHandleData(resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } } -func (e *Exchange) wsHandleData(respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { var resp WsResponse err := json.Unmarshal(respRaw, &resp) if err != nil { @@ -106,7 +108,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, AssetType: asset.Spot, Last: tick.PreviousClosePrice, @@ -118,7 +120,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { QuoteVolume: tick.Value, Volume: tick.Volume, LastUpdated: lu, - } + }) case "transaction": if !e.IsSaveTradeDataEnabled() { return nil diff --git a/exchanges/bithumb/bithumb_websocket_test.go b/exchanges/bithumb/bithumb_websocket_test.go index 3cc801ce11c..b8c2113bef6 100644 --- a/exchanges/bithumb/bithumb_websocket_test.go +++ b/exchanges/bithumb/bithumb_websocket_test.go @@ -55,21 +55,21 @@ func TestWsHandleData(t *testing.T) { dummy.API.Endpoints = e.NewEndpoints() welcomeMsg := []byte(`{"status":"0000","resmsg":"Connected Successfully"}`) - err := dummy.wsHandleData(welcomeMsg) + err := dummy.wsHandleData(t.Context(), welcomeMsg) require.NoError(t, err) - err = dummy.wsHandleData([]byte(`{"status":"1336","resmsg":"Failed"}`)) + err = dummy.wsHandleData(t.Context(), []byte(`{"status":"1336","resmsg":"Failed"}`)) require.ErrorIs(t, err, websocket.ErrSubscriptionFailure) - err = dummy.wsHandleData(wsTransResp) + err = dummy.wsHandleData(t.Context(), wsTransResp) require.NoError(t, err) - err = dummy.wsHandleData(wsOrderbookResp) + err = dummy.wsHandleData(t.Context(), wsOrderbookResp) require.NoError(t, err) - err = dummy.wsHandleData(wsTickerResp) + err = dummy.wsHandleData(t.Context(), wsTickerResp) require.NoError(t, err) - assert.IsType(t, new(ticker.Price), <-dummy.Websocket.DataHandler, "ticker should send a price to the DataHandler") + assert.IsType(t, new(ticker.Price), (<-dummy.Websocket.DataHandler.C).Data, "ticker should send a price to the DataHandler") } func TestSubToReq(t *testing.T) { diff --git a/exchanges/bitmex/bitmex_test.go b/exchanges/bitmex/bitmex_test.go index 0cf7cc77196..263ef49e0c1 100644 --- a/exchanges/bitmex/bitmex_test.go +++ b/exchanges/bitmex/bitmex_test.go @@ -724,13 +724,13 @@ func TestWsAuth(t *testing.T) { err := e.Websocket.Conn.Dial(t.Context(), &dialer, http.Header{}) require.NoError(t, err) - go e.wsReadData() + go e.wsReadData(t.Context()) err = e.websocketSendAuth(t.Context()) require.NoError(t, err) timer := time.NewTimer(sharedtestvalues.WebsocketResponseDefaultTimeout) select { - case resp := <-e.Websocket.DataHandler: - sub, ok := resp.(WebsocketSubscribeResp) + case resp := <-e.Websocket.DataHandler.C: + sub, ok := resp.Data.(WebsocketSubscribeResp) if !ok { t.Fatal("unable to type assert WebsocketSubscribeResp") } @@ -760,7 +760,7 @@ func TestWsPositionUpdate(t *testing.T) { "unrealisedGrossPnl":-677,"unrealisedPnl":-677,"unrealisedPnlPcnt":-0.0078,"unrealisedRoePcnt":-0.7756, "simpleQty":0.001,"liquidationPrice":1140.1, "timestamp":"2017-04-04T22:07:45.442Z" }]}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) } @@ -782,7 +782,7 @@ func TestWsInsertExectuionUpdate(t *testing.T) { "homeNotional":-0.00088155,"foreignNotional":1,"transactTime":"2017-04-04T22:07:46.035Z", "timestamp":"2017-04-04T22:07:46.035Z" }]}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) } @@ -795,7 +795,7 @@ func TestWSPositionUpdateHandling(t *testing.T) { "markPrice":1136.88,"posState":"Liquidated","simpleQty":0.001,"liquidationPrice":1140.1,"bankruptPrice":1134.37, "timestamp":"2017-04-04T22:07:46.019Z" }]}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) pressXToJSON = []byte(`{"table":"position", "action":"update", @@ -811,7 +811,7 @@ func TestWSPositionUpdateHandling(t *testing.T) { "avgEntryPrice":null,"breakEvenPrice":null,"marginCallPrice":null,"liquidationPrice":null,"bankruptPrice":null, "timestamp":"2017-04-04T22:07:46.140Z" }]}`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) } @@ -832,7 +832,7 @@ func TestWSOrderbookHandling(t *testing.T) { {"symbol":"ETHUSD","id":17999996000,"side":"Buy","size":20,"price":40}, {"symbol":"ETHUSD","id":17999997000,"side":"Buy","size":100,"price":30} ]}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) pressXToJSON = []byte(`{ @@ -841,14 +841,14 @@ func TestWSOrderbookHandling(t *testing.T) { "data":[ {"symbol":"ETHUSD","id":17999995000,"side":"Buy","size":5,"timestamp":"2017-04-04T22:16:38.461Z"} ]}`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) pressXToJSON = []byte(`{ "table":"orderBookL2_25", "action":"update", "data":[]}`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) require.ErrorContains(t, err, "empty orderbook") pressXToJSON = []byte(`{ @@ -857,7 +857,7 @@ func TestWSOrderbookHandling(t *testing.T) { "data":[ {"symbol":"ETHUSD","id":17999995000,"side":"Buy","timestamp":"2017-04-04T22:16:38.461Z"} ]}`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) pressXToJSON = []byte(`{ @@ -866,7 +866,7 @@ func TestWSOrderbookHandling(t *testing.T) { "data":[ {"symbol":"ETHUSD","id":17999995000,"side":"Buy","timestamp":"2017-04-04T22:16:38.461Z"} ]}`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) assert.ErrorIs(t, err, orderbook.ErrOrderbookInvalid) } @@ -879,7 +879,7 @@ func TestWSDeleveragePositionUpdateHandling(t *testing.T) { "markPrice":1160.72,"posState":"Deleverage","simpleQty":1.746,"liquidationPrice":1140.1, "timestamp":"2017-04-04T22:16:38.460Z" }]}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) pressXToJSON = []byte(`{"table":"position", @@ -897,7 +897,7 @@ func TestWSDeleveragePositionUpdateHandling(t *testing.T) { "avgEntryPrice":null,"breakEvenPrice":null,"marginCallPrice":null,"liquidationPrice":null,"bankruptPrice":null, "timestamp":"2017-04-04T22:16:38.547Z" }]}`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) } @@ -919,7 +919,7 @@ func TestWSDeleverageExecutionInsertHandling(t *testing.T) { "homeNotional":-1.72306,"foreignNotional":2000,"transactTime":"2017-04-04T22:16:38.472Z", "timestamp":"2017-04-04T22:16:38.472Z" }]}`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) require.NoError(t, err) } @@ -929,13 +929,13 @@ func TestWsTrades(t *testing.T) { require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") e.SetSaveTradeDataStatus(true) msg := []byte(`{"table":"trade","action":"insert","data":[{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.3,"tickDirection":"MinusTick","trdMatchID":"c427f7a0-6b26-1e10-5c4e-1bd74daf2a73","grossValue":2583000,"homeNotional":0.9904912836767037,"foreignNotional":255.84389857369254},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.3,"tickDirection":"ZeroMinusTick","trdMatchID":"95eb9155-b58c-70e9-44b7-34efe50302e0","grossValue":2583000,"homeNotional":0.9904912836767037,"foreignNotional":255.84389857369254},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.3,"tickDirection":"ZeroMinusTick","trdMatchID":"e607c187-f25c-86bc-cb39-8afff7aaf2d9","grossValue":2583000,"homeNotional":0.9904912836767037,"foreignNotional":255.84389857369254},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":17,"price":258.3,"tickDirection":"ZeroMinusTick","trdMatchID":"0f076814-a57d-9a59-8063-ad6b823a80ac","grossValue":439110,"homeNotional":0.1683835182250396,"foreignNotional":43.49346275752773},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.25,"tickDirection":"MinusTick","trdMatchID":"f4ef3dfd-51c4-538f-37c1-e5071ba1c75d","grossValue":2582500,"homeNotional":0.9904912836767037,"foreignNotional":255.79437400950872},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.25,"tickDirection":"ZeroMinusTick","trdMatchID":"81ef136b-8f4a-b1cf-78a8-fffbfa89bf40","grossValue":2582500,"homeNotional":0.9904912836767037,"foreignNotional":255.79437400950872},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.25,"tickDirection":"ZeroMinusTick","trdMatchID":"65a87e8c-7563-34a4-d040-94e8513c5401","grossValue":2582500,"homeNotional":0.9904912836767037,"foreignNotional":255.79437400950872},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":15,"price":258.25,"tickDirection":"ZeroMinusTick","trdMatchID":"1d11a74e-a157-3f33-036d-35a101fba50b","grossValue":387375,"homeNotional":0.14857369255150554,"foreignNotional":38.369156101426306},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":1,"price":258.25,"tickDirection":"ZeroMinusTick","trdMatchID":"40d49df1-f018-f66f-4ca5-31d4997641d7","grossValue":25825,"homeNotional":0.009904912836767036,"foreignNotional":2.5579437400950873},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.2,"tickDirection":"MinusTick","trdMatchID":"36135b51-73e5-c007-362b-a55be5830c6b","grossValue":2582000,"homeNotional":0.9904912836767037,"foreignNotional":255.7448494453249},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.2,"tickDirection":"ZeroMinusTick","trdMatchID":"6ee19edb-99aa-3030-ba63-933ffb347ade","grossValue":2582000,"homeNotional":0.9904912836767037,"foreignNotional":255.7448494453249},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":100,"price":258.2,"tickDirection":"ZeroMinusTick","trdMatchID":"d44be603-cdb8-d676-e3e2-f91fb12b2a70","grossValue":2582000,"homeNotional":0.9904912836767037,"foreignNotional":255.7448494453249},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":5,"price":258.2,"tickDirection":"ZeroMinusTick","trdMatchID":"a14b43b3-50b4-c075-c54d-dfb0165de33d","grossValue":129100,"homeNotional":0.04952456418383518,"foreignNotional":12.787242472266245},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":8,"price":258.2,"tickDirection":"ZeroMinusTick","trdMatchID":"3c30e175-5194-320c-8f8c-01636c2f4a32","grossValue":206560,"homeNotional":0.07923930269413629,"foreignNotional":20.45958795562599},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":50,"price":258.2,"tickDirection":"ZeroMinusTick","trdMatchID":"5b803378-760b-4919-21fc-bfb275d39ace","grossValue":1291000,"homeNotional":0.49524564183835185,"foreignNotional":127.87242472266244},{"timestamp":"2020-02-17T01:35:36.442Z","symbol":"ETHUSD","side":"Sell","size":244,"price":258.2,"tickDirection":"ZeroMinusTick","trdMatchID":"cf57fec1-c444-b9e5-5e2d-4fb643f4fdb7","grossValue":6300080,"homeNotional":2.416798732171157,"foreignNotional":624.0174326465927}]}`) - require.NoError(t, e.wsHandleData(msg), "Must not error handling a standard stream of trades") + require.NoError(t, e.wsHandleData(t.Context(), msg), "Must not error handling a standard stream of trades") msg = []byte(`{"table":"trade","action":"insert","data":[{"timestamp":"2020-02-17T01:35:36.442Z","symbol":".BGCT","size":14,"price":258.2,"side":"sell"}]}`) - require.ErrorIs(t, e.wsHandleData(msg), exchange.ErrSymbolNotMatched, "Must error correctly with an unknown symbol") + require.ErrorIs(t, e.wsHandleData(t.Context(), msg), exchange.ErrSymbolNotMatched, "Must error correctly with an unknown symbol") msg = []byte(`{"table":"trade","action":"insert","data":[{"timestamp":"2020-02-17T01:35:36.442Z","symbol":".BGCT","size":0,"price":258.2,"side":"sell"}]}`) - require.NoError(t, e.wsHandleData(msg), "Must not error that symbol is unknown when index trade is ignored due to zero size") + require.NoError(t, e.wsHandleData(t.Context(), msg), "Must not error that symbol is unknown when index trade is ignored due to zero size") } func TestGetRecentTrades(t *testing.T) { diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 9d78038ff2b..d37d78849b1 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -95,7 +95,7 @@ func (e *Exchange) WsConnect() error { } e.Websocket.Wg.Add(1) - go e.wsReadData() + go e.wsReadData(ctx) if e.Websocket.CanUseAuthenticatedEndpoints() { if err := e.websocketSendAuth(ctx); err != nil { @@ -113,7 +113,7 @@ const ( ) // wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData() { +func (e *Exchange) wsReadData(ctx context.Context) { defer e.Websocket.Wg.Done() for { @@ -121,14 +121,15 @@ func (e *Exchange) wsReadData() { if resp.Raw == nil { return } - err := e.wsHandleData(resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } -func (e *Exchange) wsHandleData(respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { // We don't need to know about errors, since we're looking optimistically into the json op, _ := jsonparser.GetString(respRaw, "request", "op") errMsg, _ := jsonparser.GetString(respRaw, "error") @@ -199,13 +200,13 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return nil } - e.Websocket.DataHandler <- announcement.Data + return e.Websocket.DataHandler.Send(ctx, announcement.Data) case bitmexWSAffiliate: var response WsAffiliateResponse if err := json.Unmarshal(respRaw, &response); err != nil { return err } - e.Websocket.DataHandler <- response + return e.Websocket.DataHandler.Send(ctx, response) case bitmexWSInstrument: // ticker case bitmexWSExecution: @@ -222,21 +223,13 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { } oStatus, err := order.StringToOrderStatus(response.Data[i].OrdStatus) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[i].OrderID, - Err: err, - } + return err } oSide, err := order.StringToOrderSide(response.Data[i].Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[i].OrderID, - Err: err, - } + return err } - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, OrderID: response.Data[i].OrderID, AccountID: strconv.FormatInt(response.Data[i].Account, 10), @@ -254,6 +247,8 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { IsMaker: false, }, }, + }); err != nil { + return err } } case bitmexWSOrder: @@ -270,29 +265,17 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { } oSide, err := order.StringToOrderSide(response.Data[x].Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[x].OrderID, - Err: err, - } + return err } oType, err := order.StringToOrderType(response.Data[x].OrderType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[x].OrderID, - Err: err, - } + return err } oStatus, err := order.StringToOrderStatus(response.Data[x].OrderStatus) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[x].OrderID, - Err: err, - } + return err } - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: response.Data[x].Price, Amount: response.Data[x].OrderQuantity, Exchange: e.Name, @@ -304,6 +287,8 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { AssetType: a, Date: response.Data[x].TransactTime, Pair: p, + }); err != nil { + return err } } case "delete": @@ -315,31 +300,19 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { var oSide order.Side oSide, err = order.StringToOrderSide(response.Data[x].Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[x].OrderID, - Err: err, - } + return err } var oType order.Type oType, err = order.StringToOrderType(response.Data[x].OrderType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[x].OrderID, - Err: err, - } + return err } var oStatus order.Status oStatus, err = order.StringToOrderStatus(response.Data[x].OrderStatus) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[x].OrderID, - Err: err, - } + return err } - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: response.Data[x].Price, Amount: response.Data[x].OrderQuantity, Exchange: e.Name, @@ -351,17 +324,19 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { AssetType: a, Date: response.Data[x].TransactTime, Pair: p, + }); err != nil { + return err } } default: - e.Websocket.DataHandler <- fmt.Errorf("%s - Unsupported order update %+v", e.Name, response) + return e.Websocket.DataHandler.Send(ctx, fmt.Errorf("%s - Unsupported order update %+v", e.Name, response)) } case bitmexWSMargin: var response WsMarginResponse if err := json.Unmarshal(respRaw, &response); err != nil { return err } - e.Websocket.DataHandler <- response + return e.Websocket.DataHandler.Send(ctx, response) case bitmexWSPosition: var response WsPositionResponse if err := json.Unmarshal(respRaw, &response); err != nil { @@ -372,21 +347,21 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { if err := json.Unmarshal(respRaw, &response); err != nil { return err } - e.Websocket.DataHandler <- response + return e.Websocket.DataHandler.Send(ctx, response) case bitmexWSTransact: var response WsTransactResponse if err := json.Unmarshal(respRaw, &response); err != nil { return err } - e.Websocket.DataHandler <- response + return e.Websocket.DataHandler.Send(ctx, response) case bitmexWSWallet: var response WsWalletResponse if err := json.Unmarshal(respRaw, &response); err != nil { return err } - e.Websocket.DataHandler <- response + return e.Websocket.DataHandler.Send(ctx, response) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } return nil diff --git a/exchanges/bitstamp/bitstamp_test.go b/exchanges/bitstamp/bitstamp_test.go index eed69389a96..78adc91e88b 100644 --- a/exchanges/bitstamp/bitstamp_test.go +++ b/exchanges/bitstamp/bitstamp_test.go @@ -744,12 +744,12 @@ func TestWsOrderUpdate(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") testexch.FixtureToDataHandler(t, "testdata/wsMyOrders.json", e.wsHandleData) - close(e.Websocket.DataHandler) - assert.Len(t, e.Websocket.DataHandler, 8, "Should see 8 orders") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + e.Websocket.DataHandler.Close() + assert.Len(t, e.Websocket.DataHandler.C, 8, "Should see 8 orders") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case *order.Detail: - switch len(e.Websocket.DataHandler) { + switch len(e.Websocket.DataHandler.C) { case 7: assert.Equal(t, "1658864794234880", v.OrderID, "OrderID") assert.Equal(t, time.UnixMicro(1693831262313000), v.Date, "Date") diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index 175352b5295..0c867a2ee38 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -75,7 +75,9 @@ func (e *Exchange) WsConnect() error { }) err = e.seedOrderBook(ctx) if err != nil { - e.Websocket.DataHandler <- err + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } e.Websocket.Wg.Add(1) @@ -94,12 +96,14 @@ func (e *Exchange) wsReadData(ctx context.Context) { return } if err := e.wsHandleData(ctx, resp.Raw); err != nil { - e.Websocket.DataHandler <- err + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } -func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { event, err := jsonparser.GetUnsafeString(respRaw, "event") if err != nil { return fmt.Errorf("%w `event`: %w", common.ErrParsingWSField, err) @@ -116,7 +120,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { case "trade": return e.handleWSTrade(respRaw) case "order_created", "order_deleted", "order_changed": - return e.handleWSOrder(event, respRaw) + return e.handleWSOrder(ctx, event, respRaw) case "request_reconnect": go func() { if err := e.Websocket.Shutdown(); err != nil { // Connection monitor will reconnect @@ -124,7 +128,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { } }() default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } return nil } @@ -169,7 +173,7 @@ func (e *Exchange) handleWSTrade(msg []byte) error { }) } -func (e *Exchange) handleWSOrder(event string, msg []byte) error { +func (e *Exchange) handleWSOrder(ctx context.Context, event string, msg []byte) error { channel, p, err := e.parseChannelName(msg) if err != nil { return err @@ -221,9 +225,7 @@ func (e *Exchange) handleWSOrder(event string, msg []byte) error { Pair: p, } - e.Websocket.DataHandler <- d - - return nil + return e.Websocket.DataHandler.Send(ctx, d) } func (e *Exchange) generateSubscriptions() (subscription.List, error) { diff --git a/exchanges/btcmarkets/btcmarkets_test.go b/exchanges/btcmarkets/btcmarkets_test.go index ce9c1d7809e..3ecff7da73c 100644 --- a/exchanges/btcmarkets/btcmarkets_test.go +++ b/exchanges/btcmarkets/btcmarkets_test.go @@ -1,6 +1,7 @@ package btcmarkets import ( + "encoding/base64" "log" "os" "testing" @@ -10,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchange/accounts" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" @@ -488,7 +490,7 @@ func TestWSTrade(t *testing.T) { assert.ErrorContains(t, fErrs[0].Err, "WRONG", "Side.UnmarshalJSON errors should propagate correctly") assert.ErrorIs(t, fErrs[1].Err, order.ErrSideIsInvalid, "wsHandleData errors should propagate correctly") assert.ErrorContains(t, fErrs[1].Err, "ANY", "wsHandleData errors should propagate correctly") - close(e.Websocket.DataHandler) + e.Websocket.DataHandler.Close() exp := []trade.Data{ { @@ -512,12 +514,12 @@ func TestWSTrade(t *testing.T) { AssetType: asset.Spot, }, } - require.Len(t, e.Websocket.DataHandler, 2, "Must see correct number of trades") + require.Len(t, e.Websocket.DataHandler.C, 2, "Must see correct number of trades") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case trade.Data: - i := 1 - len(e.Websocket.DataHandler) + i := 1 - len(e.Websocket.DataHandler.C) require.Equalf(t, exp[i], v, "Trade[%d] must be correct", i) case error: t.Error(v) @@ -612,6 +614,10 @@ func TestWsHeartbeats(t *testing.T) { } func TestWsOrders(t *testing.T) { + ctx := accounts.DeployCredentialsToContext(t.Context(), &accounts.Credentials{ + Key: "testkey", + Secret: base64.StdEncoding.EncodeToString([]byte("testsecret")), + }) pressXToJSON := []byte(`{ "orderId": 79003, "marketId": "BTC-AUD", @@ -624,7 +630,7 @@ func TestWsOrders(t *testing.T) { "timestamp": "2019-04-08T20:41:19.339Z", "messageType": "orderChange" }`) - err := e.wsHandleData(t.Context(), pressXToJSON) + err := e.wsHandleData(ctx, pressXToJSON) if err != nil { t.Error(err) } @@ -647,7 +653,7 @@ func TestWsOrders(t *testing.T) { "timestamp": "2019-04-08T20:50:39.658Z", "messageType": "orderChange" }`) - err = e.wsHandleData(t.Context(), pressXToJSON) + err = e.wsHandleData(ctx, pressXToJSON) if err != nil { t.Error(err) } @@ -664,7 +670,7 @@ func TestWsOrders(t *testing.T) { "timestamp": "2019-04-08T20:41:41.857Z", "messageType": "orderChange" }`) - err = e.wsHandleData(t.Context(), pressXToJSON) + err = e.wsHandleData(ctx, pressXToJSON) if err != nil { t.Error(err) } @@ -687,7 +693,7 @@ func TestWsOrders(t *testing.T) { "timestamp": "2019-04-08T20:41:41.857Z", "messageType": "orderChange" }`) - err = e.wsHandleData(t.Context(), pressXToJSON) + err = e.wsHandleData(ctx, pressXToJSON) if err != nil { t.Error(err) } @@ -704,7 +710,7 @@ func TestWsOrders(t *testing.T) { "timestamp": "2019-04-08T20:41:41.857Z", "messageType": "orderChange" }`) - err = e.wsHandleData(t.Context(), pressXToJSON) + err = e.wsHandleData(ctx, pressXToJSON) if err != nil { t.Error(err) } diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index fbd4eeb9cbb..4d9ec3d5142 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -78,9 +78,10 @@ func (e *Exchange) wsReadData(ctx context.Context) { if resp.Raw == nil { return } - err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } @@ -188,7 +189,9 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { } if tradeFeed { - e.Websocket.DataHandler <- td + if err := e.Websocket.DataHandler.Send(ctx, td); err != nil { + return err + } } if saveTradeData { return trade.AddTradesToBuffer(td) @@ -200,7 +203,7 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Volume: tick.Volume, High: tick.High24, @@ -211,14 +214,14 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { LastUpdated: tick.Timestamp, AssetType: asset.Spot, Pair: tick.MarketID, - } + }) case fundChange: var transferData WsFundTransfer err := json.Unmarshal(respRaw, &transferData) if err != nil { return err } - e.Websocket.DataHandler <- transferData + return e.Websocket.DataHandler.Send(ctx, transferData) case orderChange: var orderData WsOrderChange err := json.Unmarshal(respRaw, &orderData) @@ -247,41 +250,25 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { } oType, err := order.StringToOrderType(orderData.OrderType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } oSide, err := order.StringToOrderSide(orderData.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } oStatus, err := order.StringToOrderStatus(orderData.Status) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } clientID := "" if creds, err := e.GetCredentials(ctx); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return err } else if creds != nil { clientID = creds.ClientID } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: price, Amount: originalAmount, RemainingAmount: orderData.OpenVolume, @@ -295,7 +282,7 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { Date: orderData.Timestamp, Trades: trades, Pair: orderData.MarketID, - } + }) case "error": var wsErr WsError err := json.Unmarshal(respRaw, &wsErr) @@ -304,8 +291,7 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { } return fmt.Errorf("%v websocket error. Code: %v Message: %v", e.Name, wsErr.Code, wsErr.Message) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} - return nil + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } return nil } diff --git a/exchanges/btse/btse_test.go b/exchanges/btse/btse_test.go index 31057c687c9..680fda06718 100644 --- a/exchanges/btse/btse_test.go +++ b/exchanges/btse/btse_test.go @@ -471,7 +471,7 @@ func TestWSTrades(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) - close(e.Websocket.DataHandler) + e.Websocket.DataHandler.Close() exp := []trade.Data{ { @@ -495,11 +495,11 @@ func TestWSTrades(t *testing.T) { AssetType: asset.Spot, }, } - require.Len(t, e.Websocket.DataHandler, 2, "Must see the correct number of trades") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + require.Len(t, e.Websocket.DataHandler.C, 2, "Must see the correct number of trades") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case trade.Data: - i := 1 - len(e.Websocket.DataHandler) + i := 1 - len(e.Websocket.DataHandler.C) require.Equalf(t, exp[i], v, "Trade [%d] must be correct", i) case error: t.Error(v) diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 22c545fc55b..2f71d427e7e 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -61,8 +61,8 @@ func (e *Exchange) WsConnect() error { if e.IsWebsocketAuthenticationSupported() { err = e.WsAuthenticate(ctx) if err != nil { - e.Websocket.DataHandler <- err e.Websocket.SetCanUseAuthenticatedEndpoints(false) + return err } } @@ -120,14 +120,15 @@ func (e *Exchange) wsReadData(ctx context.Context) { if resp.Raw == nil { return } - err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } -func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { type Result map[string]any var result Result err := json.Unmarshal(respRaw, &result) @@ -189,27 +190,16 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { var oStatus order.Status oType, err = order.StringToOrderType(notification.Data[i].Type) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: notification.Data[i].OrderID, - Err: err, - } + return err } - oSide, err = order.StringToOrderSide(notification.Data[i].OrderMode) + + oSide, err = order.StringToOrderSide(strings.ReplaceAll(notification.Data[i].OrderMode, "MODE_", "")) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: notification.Data[i].OrderID, - Err: err, - } + return err } oStatus, err = stringToOrderStatus(notification.Data[i].Status) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: notification.Data[i].OrderID, - Err: err, - } + return err } var p currency.Pair @@ -224,7 +214,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { return err } - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: notification.Data[i].Price, Amount: notification.Data[i].Size, TriggerPrice: notification.Data[i].TriggerPrice, @@ -236,6 +226,8 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { AssetType: a, Date: notification.Data[i].Timestamp.Time(), Pair: p, + }); err != nil { + return err } } case strings.Contains(topic, "tradeHistoryApi"): @@ -275,7 +267,9 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { } if tradeFeed { for i := range trades { - e.Websocket.DataHandler <- trades[i] + if err := e.Websocket.DataHandler.Send(ctx, trades[i]); err != nil { + return err + } } } if saveTradeData { diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index e260d7f415c..1ed20670579 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -3020,7 +3020,7 @@ func TestWSHandleData(t *testing.T) { keys := slices.Collect(maps.Keys(pushDataMap)) slices.Sort(keys) for x := range keys { - err := e.wsHandleData(nil, asset.Spot, []byte(pushDataMap[keys[x]])) + err := e.wsHandleData(t.Context(), nil, asset.Spot, []byte(pushDataMap[keys[x]])) if keys[x] == "unhandled" { assert.ErrorIs(t, err, errUnhandledStreamData, "wsHandleData should error correctly for unhandled topics") } else { @@ -3054,15 +3054,15 @@ func TestWSHandleAuthenticatedData(t *testing.T) { } return e.wsHandleAuthenticatedData(ctx, &FixtureConnection{match: websocket.NewMatch()}, r) }) - close(e.Websocket.DataHandler) - require.Len(t, e.Websocket.DataHandler, 6, "Should see correct number of messages") + e.Websocket.DataHandler.Close() + require.Len(t, e.Websocket.DataHandler.C, 6, "Should see correct number of messages") require.Len(t, fErrs, 1, "Must get exactly one error message") assert.ErrorContains(t, fErrs[0].Err, "cannot save holdings: nil pointer: *accounts.Accounts") i := 0 - for data := range e.Websocket.DataHandler { + for data := range e.Websocket.DataHandler.C { i++ - switch v := data.(type) { + switch v := data.Data.(type) { case WsPositions: require.Len(t, v, 1, "must see 1 position") assert.Zero(t, v[0].PositionIdx, "PositionIdx should be 0") @@ -3194,16 +3194,16 @@ func TestWsTicker(t *testing.T) { } testexch.FixtureToDataHandler(t, "testdata/wsTicker.json", func(_ context.Context, r []byte) error { defer slices.Delete(assetRouting, 0, 1) - return e.wsHandleData(nil, assetRouting[0], r) + return e.wsHandleData(t.Context(), nil, assetRouting[0], r) }) - close(e.Websocket.DataHandler) + e.Websocket.DataHandler.Close() expected := 8 - require.Len(t, e.Websocket.DataHandler, expected, "Should see correct number of tickers") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + require.Len(t, e.Websocket.DataHandler.C, expected, "Should see correct number of tickers") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case *ticker.Price: assert.Equal(t, e.Name, v.ExchangeName, "ExchangeName should be correct") - switch expected - len(e.Websocket.DataHandler) { + switch expected - len(e.Websocket.DataHandler.C) { case 1: // Spot assert.Equal(t, currency.BTC, v.Pair.Base, "Pair base should be correct") assert.Equal(t, currency.USDT, v.Pair.Quote, "Pair quote should be correct") @@ -3443,7 +3443,7 @@ func TestFetchTradablePairs(t *testing.T) { func TestDeltaUpdateOrderbook(t *testing.T) { t.Parallel() data := []byte(`{"topic":"orderbook.50.WEMIXUSDT","ts":1697573183768,"type":"snapshot","data":{"s":"WEMIXUSDT","b":[["0.9511","260.703"],["0.9677","0"]],"a":[],"u":3119516,"seq":14126848493},"cts":1728966699481}`) - err := e.wsHandleData(nil, asset.Spot, data) + err := e.wsHandleData(t.Context(), nil, asset.Spot, data) require.NoError(t, err, "wsHandleData must not error") update := []byte(`{"topic":"orderbook.50.WEMIXUSDT","ts":1697573183768,"type":"delta","data":{"s":"WEMIXUSDT","b":[["0.9511","260.703"],["0.9677","0"]],"a":[],"u":3119516,"seq":14126848493},"cts":1728966699481}`) var wsResponse WebsocketResponse @@ -3883,7 +3883,7 @@ func TestHandleNoTopicWebsocketResponse(t *testing.T) { } { t.Run(fmt.Sprintf("operation: %s, requestID: %s", tc.operation, tc.requestID), func(t *testing.T) { t.Parallel() - err := e.handleNoTopicWebsocketResponse(&FixtureConnection{match: websocket.NewMatch()}, &WebsocketResponse{Operation: tc.operation, RequestID: tc.requestID}, nil) + err := e.handleNoTopicWebsocketResponse(t.Context(), &FixtureConnection{match: websocket.NewMatch()}, &WebsocketResponse{Operation: tc.operation, RequestID: tc.requestID}, nil) assert.ErrorIs(t, err, tc.error, "handleNoTopicWebsocketResponse should return expected error") }) } diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index a247e3467be..e2b2c56f13f 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -228,13 +228,13 @@ func (e *Exchange) wsHandleTradeData(conn websocket.Connection, respRaw []byte) } } -func (e *Exchange) wsHandleData(conn websocket.Connection, assetType asset.Item, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, conn websocket.Connection, assetType asset.Item, respRaw []byte) error { var result WebsocketResponse if err := json.Unmarshal(respRaw, &result); err != nil { return err } if result.Topic == "" { - return e.handleNoTopicWebsocketResponse(conn, &result, respRaw) + return e.handleNoTopicWebsocketResponse(ctx, conn, &result, respRaw) } topicSplit := strings.Split(result.Topic, ".") switch topicSplit[0] { @@ -243,17 +243,17 @@ func (e *Exchange) wsHandleData(conn websocket.Connection, assetType asset.Item, case chanPublicTrade: return e.wsProcessPublicTrade(assetType, &result) case chanPublicTicker: - return e.wsProcessPublicTicker(assetType, &result) + return e.wsProcessPublicTicker(ctx, assetType, &result) case chanKline: - return e.wsProcessKline(assetType, &result, topicSplit) + return e.wsProcessKline(ctx, assetType, &result, topicSplit) case chanLiquidation: - return e.wsProcessLiquidation(&result) + return e.wsProcessLiquidation(ctx, &result) case chanLeverageTokenKline: - return e.wsProcessLeverageTokenKline(assetType, &result, topicSplit) + return e.wsProcessLeverageTokenKline(ctx, assetType, &result, topicSplit) case chanLeverageTokenTicker: - return e.wsProcessLeverageTokenTicker(assetType, &result) + return e.wsProcessLeverageTokenTicker(ctx, assetType, &result) case chanLeverageTokenNav: - return e.wsLeverageTokenNav(&result) + return e.wsLeverageTokenNav(ctx, &result) } return fmt.Errorf("%w %s", errUnhandledStreamData, string(respRaw)) } @@ -264,14 +264,14 @@ func (e *Exchange) wsHandleAuthenticatedData(ctx context.Context, conn websocket return err } if result.Topic == "" { - return e.handleNoTopicWebsocketResponse(conn, &result, respRaw) + return e.handleNoTopicWebsocketResponse(ctx, conn, &result, respRaw) } topicSplit := strings.Split(result.Topic, ".") switch topicSplit[0] { case chanPositions: - return e.wsProcessPosition(&result) + return e.wsProcessPosition(ctx, &result) case chanExecution: - return e.wsProcessExecution(&result) + return e.wsProcessExecution(ctx, &result) case chanOrder: // Use first order's orderLinkId to match with an entire batch of order change requests if id, err := jsonparser.GetString(respRaw, "data", "[0]", "orderLinkId"); err == nil { @@ -279,16 +279,16 @@ func (e *Exchange) wsHandleAuthenticatedData(ctx context.Context, conn websocket return nil // If the data has been routed, return } } - return e.wsProcessOrder(&result) + return e.wsProcessOrder(ctx, &result) case chanWallet: return e.wsProcessWalletPushData(ctx, respRaw) case chanGreeks: - return e.wsProcessGreeks(respRaw) + return e.wsProcessGreeks(ctx, respRaw) } return fmt.Errorf("%w %s", errUnhandledStreamData, string(respRaw)) } -func (e *Exchange) handleNoTopicWebsocketResponse(conn websocket.Connection, result *WebsocketResponse, respRaw []byte) error { +func (e *Exchange) handleNoTopicWebsocketResponse(ctx context.Context, conn websocket.Connection, result *WebsocketResponse, respRaw []byte) error { switch result.Operation { case "subscribe", "unsubscribe", "auth": if result.RequestID != "" { @@ -296,18 +296,17 @@ func (e *Exchange) handleNoTopicWebsocketResponse(conn websocket.Connection, res } case "ping", "pong": default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: string(respRaw)} + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: string(respRaw)}) } return nil } -func (e *Exchange) wsProcessGreeks(resp []byte) error { +func (e *Exchange) wsProcessGreeks(ctx context.Context, resp []byte) error { var result GreeksResponse if err := json.Unmarshal(resp, &result); err != nil { return err } - e.Websocket.DataHandler <- &result - return nil + return e.Websocket.DataHandler.Send(ctx, &result) } func (e *Exchange) wsProcessWalletPushData(ctx context.Context, resp []byte) error { @@ -328,12 +327,11 @@ func (e *Exchange) wsProcessWalletPushData(ctx context.Context, resp []byte) err if err := e.Accounts.Save(ctx, subAccts, false); err != nil { return err } - e.Websocket.DataHandler <- subAccts - return nil + return e.Websocket.DataHandler.Send(ctx, subAccts) } // wsProcessOrder the order stream to see changes to your orders in real-time. -func (e *Exchange) wsProcessOrder(resp *WebsocketResponse) error { +func (e *Exchange) wsProcessOrder(ctx context.Context, resp *WebsocketResponse) error { var result []WebsocketOrderDetails if err := json.Unmarshal(resp.Data, &result); err != nil { return err @@ -372,11 +370,10 @@ func (e *Exchange) wsProcessOrder(resp *WebsocketResponse) error { LastUpdated: result[x].UpdatedTime.Time(), } } - e.Websocket.DataHandler <- execution - return nil + return e.Websocket.DataHandler.Send(ctx, execution) } -func (e *Exchange) wsProcessExecution(resp *WebsocketResponse) error { +func (e *Exchange) wsProcessExecution(ctx context.Context, resp *WebsocketResponse) error { var result WsExecutions if err := json.Unmarshal(resp.Data, &result); err != nil { return err @@ -404,29 +401,26 @@ func (e *Exchange) wsProcessExecution(resp *WebsocketResponse) error { Amount: result[x].ExecQty.Float64(), } } - e.Websocket.DataHandler <- executions - return nil + return e.Websocket.DataHandler.Send(ctx, executions) } -func (e *Exchange) wsProcessPosition(resp *WebsocketResponse) error { +func (e *Exchange) wsProcessPosition(ctx context.Context, resp *WebsocketResponse) error { var result WsPositions if err := json.Unmarshal(resp.Data, &result); err != nil { return err } - e.Websocket.DataHandler <- result - return nil + return e.Websocket.DataHandler.Send(ctx, result) } -func (e *Exchange) wsLeverageTokenNav(resp *WebsocketResponse) error { +func (e *Exchange) wsLeverageTokenNav(ctx context.Context, resp *WebsocketResponse) error { var result LTNav if err := json.Unmarshal(resp.Data, &result); err != nil { return err } - e.Websocket.DataHandler <- result - return nil + return e.Websocket.DataHandler.Send(ctx, result) } -func (e *Exchange) wsProcessLeverageTokenTicker(assetType asset.Item, resp *WebsocketResponse) error { +func (e *Exchange) wsProcessLeverageTokenTicker(ctx context.Context, assetType asset.Item, resp *WebsocketResponse) error { var result TickerWebsocket if err := json.Unmarshal(resp.Data, &result); err != nil { return err @@ -435,7 +429,7 @@ func (e *Exchange) wsProcessLeverageTokenTicker(assetType asset.Item, resp *Webs if err != nil { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ Last: result.LastPrice.Float64(), High: result.HighPrice24H.Float64(), Low: result.LowPrice24H.Float64(), @@ -443,11 +437,10 @@ func (e *Exchange) wsProcessLeverageTokenTicker(assetType asset.Item, resp *Webs ExchangeName: e.Name, AssetType: assetType, LastUpdated: resp.PushTimestamp.Time(), - } - return nil + }) } -func (e *Exchange) wsProcessLeverageTokenKline(assetType asset.Item, resp *WebsocketResponse, topicSplit []string) error { +func (e *Exchange) wsProcessLeverageTokenKline(ctx context.Context, assetType asset.Item, resp *WebsocketResponse, topicSplit []string) error { var result LTKlines if err := json.Unmarshal(resp.Data, &result); err != nil { return err @@ -476,20 +469,18 @@ func (e *Exchange) wsProcessLeverageTokenKline(assetType asset.Item, resp *Webso LowPrice: result[x].Low.Float64(), } } - e.Websocket.DataHandler <- result - return nil + return e.Websocket.DataHandler.Send(ctx, ltKline) } -func (e *Exchange) wsProcessLiquidation(resp *WebsocketResponse) error { +func (e *Exchange) wsProcessLiquidation(ctx context.Context, resp *WebsocketResponse) error { var result WebsocketLiquidation if err := json.Unmarshal(resp.Data, &result); err != nil { return err } - e.Websocket.DataHandler <- result - return nil + return e.Websocket.DataHandler.Send(ctx, result) } -func (e *Exchange) wsProcessKline(assetType asset.Item, resp *WebsocketResponse, topicSplit []string) error { +func (e *Exchange) wsProcessKline(ctx context.Context, assetType asset.Item, resp *WebsocketResponse, topicSplit []string) error { var result WsKlines if err := json.Unmarshal(resp.Data, &result); err != nil { return err @@ -519,11 +510,10 @@ func (e *Exchange) wsProcessKline(assetType asset.Item, resp *WebsocketResponse, Volume: result[x].Volume.Float64(), } } - e.Websocket.DataHandler <- spotCandlesticks - return nil + return e.Websocket.DataHandler.Send(ctx, spotCandlesticks) } -func (e *Exchange) wsProcessPublicTicker(assetType asset.Item, resp *WebsocketResponse) error { +func (e *Exchange) wsProcessPublicTicker(ctx context.Context, assetType asset.Item, resp *WebsocketResponse) error { var tickResp TickerWebsocket if err := json.Unmarshal(resp.Data, &tickResp); err != nil { return err @@ -547,8 +537,7 @@ func (e *Exchange) wsProcessPublicTicker(assetType asset.Item, resp *WebsocketRe if err := ticker.ProcessTicker(tick); err != nil { return err } - e.Websocket.DataHandler <- tick - return nil + return e.Websocket.DataHandler.Send(ctx, tick) } func updateTicker(tick *ticker.Price, resp *TickerWebsocket) { diff --git a/exchanges/bybit/bybit_websocket_requests_test.go b/exchanges/bybit/bybit_websocket_requests_test.go index d810b6699c6..40c4bb915f4 100644 --- a/exchanges/bybit/bybit_websocket_requests_test.go +++ b/exchanges/bybit/bybit_websocket_requests_test.go @@ -205,6 +205,6 @@ func getWebsocketInstance(t *testing.T) *Exchange { require.NoError(t, e.Setup(bConf), "Setup must not error") e.CurrencyPairs.Load(pairs) - require.NoError(t, e.Websocket.Connect()) + require.NoError(t, e.Websocket.Connect(t.Context())) return e } diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index a0cb3340d0d..9639d512b2b 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -252,8 +252,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { GenerateSubscriptions: e.generateSubscriptions, Subscriber: e.SpotSubscribe, Unsubscriber: e.SpotUnsubscribe, - Handler: func(_ context.Context, conn websocket.Connection, resp []byte) error { - return e.wsHandleData(conn, asset.Spot, resp) + Handler: func(ctx context.Context, conn websocket.Connection, resp []byte) error { + return e.wsHandleData(ctx, conn, asset.Spot, resp) }, }); err != nil { return err @@ -273,8 +273,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { GenerateSubscriptions: e.GenerateOptionsDefaultSubscriptions, Subscriber: e.OptionsSubscribe, Unsubscriber: e.OptionsUnsubscribe, - Handler: func(_ context.Context, conn websocket.Connection, resp []byte) error { - return e.wsHandleData(conn, asset.Options, resp) + Handler: func(ctx context.Context, conn websocket.Connection, resp []byte) error { + return e.wsHandleData(ctx, conn, asset.Options, resp) }, }); err != nil { return err @@ -300,8 +300,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Unsubscriber: func(ctx context.Context, conn websocket.Connection, unsub subscription.List) error { return e.LinearUnsubscribe(ctx, conn, asset.USDTMarginedFutures, unsub) }, - Handler: func(_ context.Context, conn websocket.Connection, resp []byte) error { - return e.wsHandleData(conn, asset.USDTMarginedFutures, resp) + Handler: func(ctx context.Context, conn websocket.Connection, resp []byte) error { + return e.wsHandleData(ctx, conn, asset.USDTMarginedFutures, resp) }, MessageFilter: asset.USDTMarginedFutures, // Unused but it allows us to differentiate between the two linear futures types. }); err != nil { @@ -328,8 +328,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Unsubscriber: func(ctx context.Context, conn websocket.Connection, unsub subscription.List) error { return e.LinearUnsubscribe(ctx, conn, asset.USDCMarginedFutures, unsub) }, - Handler: func(_ context.Context, conn websocket.Connection, resp []byte) error { - return e.wsHandleData(conn, asset.USDCMarginedFutures, resp) + Handler: func(ctx context.Context, conn websocket.Connection, resp []byte) error { + return e.wsHandleData(ctx, conn, asset.USDCMarginedFutures, resp) }, MessageFilter: asset.USDCMarginedFutures, // Unused but it allows us to differentiate between the two linear futures types. }); err != nil { @@ -350,8 +350,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { GenerateSubscriptions: e.GenerateInverseDefaultSubscriptions, Subscriber: e.InverseSubscribe, Unsubscriber: e.InverseUnsubscribe, - Handler: func(_ context.Context, conn websocket.Connection, resp []byte) error { - return e.wsHandleData(conn, asset.CoinMarginedFutures, resp) + Handler: func(ctx context.Context, conn websocket.Connection, resp []byte) error { + return e.wsHandleData(ctx, conn, asset.CoinMarginedFutures, resp) }, }); err != nil { return err diff --git a/exchanges/coinbase/coinbase_test.go b/exchanges/coinbase/coinbase_test.go index 5b67f5b75a8..469bbb68fa3 100644 --- a/exchanges/coinbase/coinbase_test.go +++ b/exchanges/coinbase/coinbase_test.go @@ -111,7 +111,7 @@ func TestWsConnect(t *testing.T) { assert.ErrorIs(t, err, websocket.ErrWebsocketNotEnabled) err = exchangeBaseHelper(exch) require.NoError(t, err) - err = exch.Websocket.Enable() + err = exch.Websocket.Enable(t.Context()) assert.NoError(t, err) } @@ -1515,7 +1515,7 @@ func TestWsAuth(t *testing.T) { err := e.Websocket.Conn.Dial(t.Context(), &dialer, http.Header{}) require.NoError(t, err) e.Websocket.Wg.Add(1) - go e.wsReadData() + go e.wsReadData(t.Context()) err = e.Subscribe(subscription.List{ { Channel: "myAccount", @@ -1527,7 +1527,7 @@ func TestWsAuth(t *testing.T) { assert.NoError(t, err) timer := time.NewTimer(sharedtestvalues.WebsocketResponseDefaultTimeout) select { - case badResponse := <-e.Websocket.DataHandler: + case badResponse := <-e.Websocket.DataHandler.C: assert.IsType(t, []order.Detail{}, badResponse) case <-timer.C: } @@ -1542,64 +1542,64 @@ func TestWsHandleData(t *testing.T) { go func() { for { select { - case <-e.Websocket.DataHandler: + case <-e.Websocket.DataHandler.C: continue case <-done: return } } }() - _, err := e.wsHandleData(nil) + _, err := e.wsHandleData(t.Context(), nil) var syntaxErr *json.SyntaxError assert.True(t, errors.As(err, &syntaxErr) || strings.Contains(err.Error(), "Syntax error no sources available, the input json is empty"), errJSONUnmarshalUnexpected) mockJSON := []byte(`{"type": "error"}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.Error(t, err) mockJSON = []byte(`{"sequence_num": 0, "channel": "subscriptions"}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.NoError(t, err) var unmarshalTypeErr *json.UnmarshalTypeError mockJSON = []byte(`{"sequence_num": 0, "channel": "status", "events": [{"type": 1234}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) mockJSON = []byte(`{"sequence_num": 0, "channel": "status", "events": [{"type": "moo"}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.NoError(t, err) mockJSON = []byte(`{"sequence_num": 0, "channel": "ticker", "events": [{"type": "moo", "tickers": false}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) mockJSON = []byte(`{"sequence_num": 0, "channel": "candles", "events": [{"type": false}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) mockJSON = []byte(`{"sequence_num": 0, "channel": "candles", "events": [{"type": "moo", "candles": [{"low": "1.1"}]}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.NoError(t, err) mockJSON = []byte(`{"sequence_num": 0, "channel": "market_trades", "events": [{"type": false}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) mockJSON = []byte(`{"sequence_num": 0, "channel": "market_trades", "events": [{"type": "moo", "trades": [{"price": "1.1"}]}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.NoError(t, err) mockJSON = []byte(`{"sequence_num": 0, "channel": "l2_data", "events": [{"type": false, "updates": [{"price_level": "1.1"}]}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) mockJSON = []byte(`{"sequence_num": 0, "channel": "l2_data", "timestamp": "2006-01-02T15:04:05Z", "events": [{"type": "moo", "updates": [{"price_level": "1.1"}]}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.ErrorIs(t, err, errUnknownL2DataType) mockJSON = []byte(`{"sequence_num": 0, "channel": "l2_data", "timestamp": "2006-01-02T15:04:05Z", "events": [{"type": "snapshot", "product_id": "BTC-USD", "updates": [{"side": "bid", "price_level": "1.1", "new_quantity": "2.2"}]}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.NoError(t, err) mockJSON = []byte(`{"sequence_num": 0, "channel": "l2_data", "timestamp": "2006-01-02T15:04:05Z", "events": [{"type": "update", "product_id": "BTC-USD", "updates": [{"side": "bid", "price_level": "1.1", "new_quantity": "2.2"}]}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.NoError(t, err) mockJSON = []byte(`{"sequence_num": 0, "channel": "user", "events": [{"type": false}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) - mockJSON = []byte(`{"sequence_num": 0, "channel": "user", "events": [{"type": "moo", "orders": [{"limit_price": "2.2", "total_fees": "1.1", "post_only": true}], "positions": {"perpetual_futures_positions": [{"margin_type": "fakeMarginType"}], "expiring_futures_positions": [{}]}}]}`) - _, err = e.wsHandleData(mockJSON) - assert.NoError(t, err) + mockJSON = []byte(`{"sequence_num": 0, "channel": "user", "events": [{"type": "l", "orders": [{"limit_price": "2.2", "total_fees": "1.1", "post_only": true}], "positions": {"perpetual_futures_positions": [{"margin_type": "fakeMarginType"}], "expiring_futures_positions": [{}]}}]}`) + _, err = e.wsHandleData(t.Context(), mockJSON) + assert.ErrorIs(t, err, order.ErrUnrecognisedOrderType) mockJSON = []byte(`{"sequence_num": 0, "channel": "fakechan", "events": [{"type": ""}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.ErrorIs(t, err, errChannelNameUnknown) p, err := e.FormatExchangeCurrency(currency.NewBTCUSD(), asset.Spot) require.NoError(t, err) @@ -1607,7 +1607,7 @@ func TestWsHandleData(t *testing.T) { p: {p}, }) mockJSON = []byte(`{"sequence_num": 0, "channel": "ticker", "events": [{"type": "moo", "tickers": [{"product_id": "BTC-USD", "price": "1.1"}]}]}`) - _, err = e.wsHandleData(mockJSON) + _, err = e.wsHandleData(t.Context(), mockJSON) assert.NoError(t, err) } diff --git a/exchanges/coinbase/coinbase_websocket.go b/exchanges/coinbase/coinbase_websocket.go index 7dec8d4083a..fdf64299112 100644 --- a/exchanges/coinbase/coinbase_websocket.go +++ b/exchanges/coinbase/coinbase_websocket.go @@ -21,6 +21,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/log" ) const ( @@ -66,12 +67,12 @@ func (e *Exchange) WsConnect() error { return err } e.Websocket.Wg.Add(1) - go e.wsReadData() + go e.wsReadData(ctx) return nil } // wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData() { +func (e *Exchange) wsReadData(ctx context.Context) { defer e.Websocket.Wg.Done() var seqCount uint64 for { @@ -79,13 +80,18 @@ func (e *Exchange) wsReadData() { if resp.Raw == nil { return } - sequence, err := e.wsHandleData(resp.Raw) + sequence, err := e.wsHandleData(ctx, resp.Raw) if err != nil { - e.Websocket.DataHandler <- err + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } if sequence != nil { if *sequence != seqCount { - e.Websocket.DataHandler <- fmt.Errorf("%w: received %v, expected %v", errOutOfSequence, sequence, seqCount) + err := fmt.Errorf("%w: received %v, expected %v", errOutOfSequence, sequence, seqCount) + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } seqCount = *sequence } seqCount++ @@ -94,7 +100,7 @@ func (e *Exchange) wsReadData() { } // wsProcessTicker handles ticker data from the websocket -func (e *Exchange) wsProcessTicker(resp *StandardWebsocketResponse) error { +func (e *Exchange) wsProcessTicker(ctx context.Context, resp *StandardWebsocketResponse) error { var wsTickers []WebsocketTickerHolder if err := json.Unmarshal(resp.Events, &wsTickers); err != nil { return err @@ -129,12 +135,11 @@ func (e *Exchange) wsProcessTicker(resp *StandardWebsocketResponse) error { } } } - e.Websocket.DataHandler <- allTickers - return nil + return e.Websocket.DataHandler.Send(ctx, allTickers) } // wsProcessCandle handles candle data from the websocket -func (e *Exchange) wsProcessCandle(resp *StandardWebsocketResponse) error { +func (e *Exchange) wsProcessCandle(ctx context.Context, resp *StandardWebsocketResponse) error { var wsCandles []WebsocketCandleHolder if err := json.Unmarshal(resp.Events, &wsCandles); err != nil { return err @@ -156,12 +161,11 @@ func (e *Exchange) wsProcessCandle(resp *StandardWebsocketResponse) error { }) } } - e.Websocket.DataHandler <- allCandles - return nil + return e.Websocket.DataHandler.Send(ctx, allCandles) } // wsProcessMarketTrades handles market trades data from the websocket -func (e *Exchange) wsProcessMarketTrades(resp *StandardWebsocketResponse) error { +func (e *Exchange) wsProcessMarketTrades(ctx context.Context, resp *StandardWebsocketResponse) error { var wsTrades []WebsocketMarketTradeHolder if err := json.Unmarshal(resp.Events, &wsTrades); err != nil { return err @@ -181,8 +185,7 @@ func (e *Exchange) wsProcessMarketTrades(resp *StandardWebsocketResponse) error }) } } - e.Websocket.DataHandler <- allTrades - return nil + return e.Websocket.DataHandler.Send(ctx, allTrades) } // wsProcessL2 handles l2 orderbook data from the websocket @@ -209,7 +212,7 @@ func (e *Exchange) wsProcessL2(resp *StandardWebsocketResponse) error { } // wsProcessUser handles user data from the websocket -func (e *Exchange) wsProcessUser(resp *StandardWebsocketResponse) error { +func (e *Exchange) wsProcessUser(ctx context.Context, resp *StandardWebsocketResponse) error { var wsUser []WebsocketOrderDataHolder err := json.Unmarshal(resp.Events, &wsUser) if err != nil { @@ -220,24 +223,15 @@ func (e *Exchange) wsProcessUser(resp *StandardWebsocketResponse) error { for j := range wsUser[i].Orders { var oType order.Type if oType, err = stringToStandardType(wsUser[i].Orders[j].OrderType); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } var oSide order.Side if oSide, err = order.StringToOrderSide(wsUser[i].Orders[j].OrderSide); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } var oStatus order.Status if oStatus, err = statusToStandardStatus(wsUser[i].Orders[j].Status); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } price := wsUser[i].Orders[j].AveragePrice if wsUser[i].Orders[j].LimitPrice != 0 { @@ -245,17 +239,11 @@ func (e *Exchange) wsProcessUser(resp *StandardWebsocketResponse) error { } var assetType asset.Item if assetType, err = stringToStandardAsset(wsUser[i].Orders[j].ProductType); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } var tif order.TimeInForce if tif, err = strategyDecoder(wsUser[i].Orders[j].TimeInForce); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } if wsUser[i].Orders[j].PostOnly { tif |= order.PostOnly @@ -283,17 +271,11 @@ func (e *Exchange) wsProcessUser(resp *StandardWebsocketResponse) error { for j := range wsUser[i].Positions.PerpetualFuturesPositions { var oSide order.Side if oSide, err = order.StringToOrderSide(wsUser[i].Positions.PerpetualFuturesPositions[j].PositionSide); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } var mType margin.Type if mType, err = margin.StringToMarginType(wsUser[i].Positions.PerpetualFuturesPositions[j].MarginType); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } allOrders = append(allOrders, order.Detail{ Pair: wsUser[i].Positions.PerpetualFuturesPositions[j].ProductID, @@ -308,10 +290,7 @@ func (e *Exchange) wsProcessUser(resp *StandardWebsocketResponse) error { for j := range wsUser[i].Positions.ExpiringFuturesPositions { var oSide order.Side if oSide, err = order.StringToOrderSide(wsUser[i].Positions.ExpiringFuturesPositions[j].Side); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } allOrders = append(allOrders, order.Detail{ Pair: wsUser[i].Positions.ExpiringFuturesPositions[j].ProductID, @@ -321,12 +300,11 @@ func (e *Exchange) wsProcessUser(resp *StandardWebsocketResponse) error { }) } } - e.Websocket.DataHandler <- allOrders - return nil + return e.Websocket.DataHandler.Send(ctx, allOrders) } // wsHandleData handles all the websocket data coming from the websocket connection -func (e *Exchange) wsHandleData(respRaw []byte) (*uint64, error) { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) (*uint64, error) { var resp StandardWebsocketResponse if err := json.Unmarshal(respRaw, &resp); err != nil { return nil, err @@ -342,17 +320,17 @@ func (e *Exchange) wsHandleData(respRaw []byte) (*uint64, error) { if err := json.Unmarshal(resp.Events, &wsStatus); err != nil { return &resp.Sequence, err } - e.Websocket.DataHandler <- wsStatus + return &resp.Sequence, e.Websocket.DataHandler.Send(ctx, wsStatus) case "ticker", "ticker_batch": - if err := e.wsProcessTicker(&resp); err != nil { + if err := e.wsProcessTicker(ctx, &resp); err != nil { return &resp.Sequence, err } case "candles": - if err := e.wsProcessCandle(&resp); err != nil { + if err := e.wsProcessCandle(ctx, &resp); err != nil { return &resp.Sequence, err } case "market_trades": - if err := e.wsProcessMarketTrades(&resp); err != nil { + if err := e.wsProcessMarketTrades(ctx, &resp); err != nil { return &resp.Sequence, err } case "l2_data": @@ -360,7 +338,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) (*uint64, error) { return &resp.Sequence, err } case "user": - if err := e.wsProcessUser(&resp); err != nil { + if err := e.wsProcessUser(ctx, &resp); err != nil { return &resp.Sequence, err } default: diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 34f637fb1c7..a83ec1edcbd 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -87,9 +87,10 @@ func (e *Exchange) wsReadData(ctx context.Context) { if strings.HasPrefix(string(resp.Raw), "[") { var incoming []wsResponse - err := json.Unmarshal(resp.Raw, &incoming) - if err != nil { - e.Websocket.DataHandler <- err + if err := json.Unmarshal(resp.Raw, &incoming); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } continue } for i := range incoming { @@ -98,45 +99,43 @@ func (e *Exchange) wsReadData(ctx context.Context) { break } } - var individualJSON []byte - individualJSON, err = json.Marshal(incoming[i]) + individualJSON, err := json.Marshal(incoming[i]) if err != nil { - e.Websocket.DataHandler <- err + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } continue } - err = e.wsHandleData(ctx, individualJSON) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, individualJSON); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } else { - var incoming wsResponse - err := json.Unmarshal(resp.Raw, &incoming) - if err != nil { - e.Websocket.DataHandler <- err - continue - } - err = e.wsHandleData(ctx, resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } } -func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { if strings.HasPrefix(string(respRaw), "[") { var orders []wsOrderContainer - err := json.Unmarshal(respRaw, &orders) - if err != nil { + if err := json.Unmarshal(respRaw, &orders); err != nil { return err } for i := range orders { - o, err2 := e.parseOrderContainer(&orders[i]) - if err2 != nil { - return err2 + o, err := e.parseOrderContainer(&orders[i]) + if err != nil { + return err + } + if err := e.Websocket.DataHandler.Send(ctx, o); err != nil { + return err } - e.Websocket.DataHandler <- o } return nil } @@ -176,13 +175,13 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, OrderID: strconv.FormatInt(cancel.OrderID, 10), Status: order.Cancelled, LastUpdated: time.Now(), AssetType: asset.Spot, - } + }) case "cancel_orders": var cancels WsCancelOrdersResponse err := json.Unmarshal(respRaw, &cancels) @@ -190,12 +189,14 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { return err } for i := range cancels.Results { - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, OrderID: strconv.FormatInt(cancels.Results[i].OrderID, 10), Status: order.Cancelled, LastUpdated: time.Now(), AssetType: asset.Spot, + }); err != nil { + return err } } case "trade_history": @@ -233,7 +234,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Volume: wsTicker.Volume24, QuoteVolume: wsTicker.Volume24Quote, @@ -245,27 +246,21 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { LastUpdated: wsTicker.Timestamp.Time(), AssetType: asset.Spot, Pair: p, - } + }) case "inst_order_book": var orderbookSnapshot WsOrderbookSnapshot err := json.Unmarshal(respRaw, &orderbookSnapshot) if err != nil { return err } - err = e.WsProcessOrderbookSnapshot(&orderbookSnapshot) - if err != nil { - return err - } + return e.WsProcessOrderbookSnapshot(&orderbookSnapshot) case "inst_order_book_update": var orderbookUpdate WsOrderbookUpdate err := json.Unmarshal(respRaw, &orderbookUpdate) if err != nil { return err } - err = e.WsProcessOrderbookUpdate(&orderbookUpdate) - if err != nil { - return err - } + return e.WsProcessOrderbookUpdate(&orderbookUpdate) case "inst_trade": if !e.IsSaveTradeDataEnabled() { return nil @@ -291,10 +286,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { tSide, err := order.StringToOrderSide(tradeSnap.Trades[i].Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } trades = append(trades, trade.Data{ @@ -333,10 +325,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { tSide, err := order.StringToOrderSide(tradeUpdate.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } return trade.AddTradesToBuffer(trade.Data{ @@ -359,10 +348,9 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- o + return e.Websocket.DataHandler.Send(ctx, o) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} - return nil + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } return nil } @@ -391,30 +379,18 @@ func (e *Exchange) parseOrderContainer(oContainer *wsOrderContainer) (*order.Det if oContainer.Side != "" { oSide, err = order.StringToOrderSide(oContainer.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return nil, err } } else if oContainer.Order.Side != "" { oSide, err = order.StringToOrderSide(oContainer.Order.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return nil, err } } oStatus, err = stringToOrderStatus(oContainer.Reply, oContainer.OpenQuantity) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return nil, err } if oContainer.Status[0] != "OK" { return nil, fmt.Errorf("%s - Order rejected: %v", e.Name, oContainer.Status) @@ -438,11 +414,7 @@ func (e *Exchange) parseOrderContainer(oContainer *wsOrderContainer) (*order.Det if oContainer.Reply == "order_filled" { o.Side, err = order.StringToOrderSide(oContainer.Order.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } + return nil, err } o.RemainingAmount = oContainer.Order.OpenQuantity o.Amount = oContainer.Order.Quantity diff --git a/exchanges/deribit/deribit_test.go b/exchanges/deribit/deribit_test.go index c81456d4ba4..990c61a9214 100644 --- a/exchanges/deribit/deribit_test.go +++ b/exchanges/deribit/deribit_test.go @@ -718,7 +718,7 @@ func TestWSProcessTrades(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup instance must not error") testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) - close(e.Websocket.DataHandler) + e.Websocket.DataHandler.Close() a, p, err := getAssetPairByInstrument("BTC-PERPETUAL") require.NoError(t, err, "getAssetPairByInstrument must not error") @@ -745,11 +745,11 @@ func TestWSProcessTrades(t *testing.T) { AssetType: a, }, } - require.Len(t, e.Websocket.DataHandler, len(exp), "Must see the correct number of trades") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + require.Len(t, e.Websocket.DataHandler.C, len(exp), "Must see the correct number of trades") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case trade.Data: - i := 1 - len(e.Websocket.DataHandler) + i := 1 - len(e.Websocket.DataHandler.C) require.Equalf(t, exp[i], v, "Trade [%d] must be correct", i) case error: t.Error(v) @@ -4092,26 +4092,6 @@ func TestGetResolutionFromInterval(t *testing.T) { } } -func TestGetValidatedCurrencyCode(t *testing.T) { - t.Parallel() - pairs := map[currency.Pair]string{ - currency.NewPairWithDelimiter(currencySOL, "21OCT22-20-C", "-"): currencySOL, - currency.NewPairWithDelimiter(currencyBTC, perpString, "-"): currencyBTC, - currency.NewPairWithDelimiter(currencyETH, perpString, "-"): currencyETH, - currency.NewPairWithDelimiter(currencySOL, perpString, "-"): currencySOL, - currency.NewPairWithDelimiter("AVAX_USDC", perpString, "-"): currencyUSDC, - currency.NewPairWithDelimiter(currencyBTC, "USDC", "_"): currencyBTC, - currency.NewPairWithDelimiter(currencyETH, "USDC", "_"): currencyETH, - currency.NewPairWithDelimiter("DOT", "USDC-PERPETUAL", "_"): currencyUSDC, - currency.NewPairWithDelimiter("DOT", "USDT-PERPETUAL", "_"): currencyUSDT, - currency.EMPTYPAIR: "any", - } - for x := range pairs { - result := getValidatedCurrencyCode(x) - require.Equalf(t, pairs[x], result, "expected: %s actual : %s for currency pair: %v", x, result, pairs[x]) - } -} - func TestGetCurrencyTradeURL(t *testing.T) { t.Parallel() _, err := e.GetCurrencyTradeURL(t.Context(), asset.Spot, currency.EMPTYPAIR) diff --git a/exchanges/deribit/deribit_websocket.go b/exchanges/deribit/deribit_websocket.go index 86c34334f13..71805e0522f 100644 --- a/exchanges/deribit/deribit_websocket.go +++ b/exchanges/deribit/deribit_websocket.go @@ -202,10 +202,10 @@ func (e *Exchange) wsReadData(ctx context.Context) { if resp.Raw == nil { return } - - err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } @@ -235,73 +235,72 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- announcement + return e.Websocket.DataHandler.Send(ctx, announcement) case "book": return e.processOrderbook(respRaw, channels) case "chart": - return e.processCandleChart(respRaw, channels) + return e.processCandleChart(ctx, respRaw, channels) case "deribit_price_index": indexPrice := &wsIndexPrice{} - return e.processData(respRaw, indexPrice) + return e.processData(ctx, respRaw, indexPrice) case "deribit_price_ranking": priceRankings := &wsRankingPrices{} - return e.processData(respRaw, priceRankings) + return e.processData(ctx, respRaw, priceRankings) case "deribit_price_statistics": priceStatistics := &wsPriceStatistics{} - return e.processData(respRaw, priceStatistics) + return e.processData(ctx, respRaw, priceStatistics) case "deribit_volatility_index": volatilityIndex := &wsVolatilityIndex{} - return e.processData(respRaw, volatilityIndex) + return e.processData(ctx, respRaw, volatilityIndex) case "estimated_expiration_price": estimatedExpirationPrice := &wsEstimatedExpirationPrice{} - return e.processData(respRaw, estimatedExpirationPrice) + return e.processData(ctx, respRaw, estimatedExpirationPrice) case "incremental_ticker": - return e.processIncrementalTicker(respRaw, channels) + return e.processIncrementalTicker(ctx, respRaw, channels) case "instrument": instrumentState := &wsInstrumentState{} - return e.processData(respRaw, instrumentState) + return e.processData(ctx, respRaw, instrumentState) case "markprice": markPriceOptions := []wsMarkPriceOptions{} - return e.processData(respRaw, markPriceOptions) + return e.processData(ctx, respRaw, markPriceOptions) case "perpetual": perpetualInterest := &wsPerpetualInterest{} - return e.processData(respRaw, perpetualInterest) + return e.processData(ctx, respRaw, perpetualInterest) case platformStateChannel: platformState := &wsPlatformState{} - return e.processData(respRaw, platformState) + return e.processData(ctx, respRaw, platformState) case "quote": // Quote ticker information. - return e.processQuoteTicker(respRaw, channels) + return e.processQuoteTicker(ctx, respRaw, channels) case "ticker": - return e.processInstrumentTicker(respRaw, channels) + return e.processInstrumentTicker(ctx, respRaw, channels) case "trades": - return e.processTrades(respRaw, channels) + return e.processTrades(ctx, respRaw, channels) case "user": switch channels[1] { case "access_log": accessLog := &wsAccessLog{} - return e.processData(respRaw, accessLog) + return e.processData(ctx, respRaw, accessLog) case "changes": - return e.processUserOrderChanges(respRaw, channels) + return e.processUserOrderChanges(ctx, respRaw, channels) case "lock": userLock := &WsUserLock{} - return e.processData(respRaw, userLock) + return e.processData(ctx, respRaw, userLock) case "mmp_trigger": data := &WsMMPTrigger{ Currency: channels[2], } - return e.processData(respRaw, data) + return e.processData(ctx, respRaw, data) case "orders": - return e.processUserOrders(respRaw, channels) + return e.processUserOrders(ctx, respRaw, channels) case "portfolio": portfolio := &wsUserPortfolio{} - return e.processData(respRaw, portfolio) + return e.processData(ctx, respRaw, portfolio) case "trades": - return e.processTrades(respRaw, channels) + return e.processTrades(ctx, respRaw, channels) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), - } - return nil + }) } case "public/test", "public/set_heartbeat": default: @@ -311,10 +310,9 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { return nil } default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), - } - return nil + }) } } return nil @@ -331,7 +329,7 @@ func (e *Exchange) wsSendHeartbeat(ctx context.Context) { } } -func (e *Exchange) processUserOrders(respRaw []byte, channels []string) error { +func (e *Exchange) processUserOrders(ctx context.Context, respRaw []byte, channels []string) error { if len(channels) != 4 && len(channels) != 5 { return fmt.Errorf("%w, expected format 'user.orders.{instrument_name}.raw, user.orders.{instrument_name}.{interval}, user.orders.{kind}.{currency}.raw, or user.orders.{kind}.{currency}.{interval}', but found %s", common.ErrMalformedData, strings.Join(channels, ".")) } @@ -376,11 +374,10 @@ func (e *Exchange) processUserOrders(respRaw []byte, channels []string) error { Pair: cp, } } - e.Websocket.DataHandler <- orderDetails - return nil + return e.Websocket.DataHandler.Send(ctx, orderDetails) } -func (e *Exchange) processUserOrderChanges(respRaw []byte, channels []string) error { +func (e *Exchange) processUserOrderChanges(ctx context.Context, respRaw []byte, channels []string) error { if len(channels) < 4 || len(channels) > 5 { return fmt.Errorf("%w, expected format 'trades.{instrument_name}.{interval} or trades.{kind}.{currency}.{interval}', but found %s", common.ErrMalformedData, strings.Join(channels, ".")) } @@ -454,12 +451,13 @@ func (e *Exchange) processUserOrderChanges(respRaw []byte, channels []string) er Pair: cp, } } - e.Websocket.DataHandler <- orders - e.Websocket.DataHandler <- changeData.Positions - return nil + if err := e.Websocket.DataHandler.Send(ctx, orders); err != nil { + return err + } + return e.Websocket.DataHandler.Send(ctx, changeData.Positions) } -func (e *Exchange) processQuoteTicker(respRaw []byte, channels []string) error { +func (e *Exchange) processQuoteTicker(ctx context.Context, respRaw []byte, channels []string) error { a, cp, err := getAssetPairByInstrument(channels[1]) if err != nil { return err @@ -471,7 +469,7 @@ func (e *Exchange) processQuoteTicker(respRaw []byte, channels []string) error { if err != nil { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Pair: cp, AssetType: a, @@ -480,11 +478,10 @@ func (e *Exchange) processQuoteTicker(respRaw []byte, channels []string) error { Ask: quoteTicker.BestAskPrice, BidSize: quoteTicker.BestBidAmount, AskSize: quoteTicker.BestAskAmount, - } - return nil + }) } -func (e *Exchange) processTrades(respRaw []byte, channels []string) error { +func (e *Exchange) processTrades(ctx context.Context, respRaw []byte, channels []string) error { tradeFeed := e.IsTradeFeedEnabled() saveTradeData := e.IsSaveTradeDataEnabled() if !tradeFeed && !saveTradeData { @@ -525,7 +522,9 @@ func (e *Exchange) processTrades(respRaw []byte, channels []string) error { } if tradeFeed { for i := range tradesData { - e.Websocket.DataHandler <- tradesData[i] + if err := e.Websocket.DataHandler.Send(ctx, tradesData[i]); err != nil { + return err + } } } if saveTradeData { @@ -534,7 +533,7 @@ func (e *Exchange) processTrades(respRaw []byte, channels []string) error { return nil } -func (e *Exchange) processIncrementalTicker(respRaw []byte, channels []string) error { +func (e *Exchange) processIncrementalTicker(ctx context.Context, respRaw []byte, channels []string) error { if len(channels) != 2 { return fmt.Errorf("%w, expected format 'incremental_ticker.{instrument_name}', but found %s", common.ErrMalformedData, strings.Join(channels, ".")) } @@ -549,7 +548,7 @@ func (e *Exchange) processIncrementalTicker(respRaw []byte, channels []string) e if err != nil { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Pair: cp, AssetType: a, @@ -562,18 +561,17 @@ func (e *Exchange) processIncrementalTicker(respRaw []byte, channels []string) e QuoteVolume: incrementalTicker.Stats.VolumeUsd, Ask: incrementalTicker.ImpliedAsk, Bid: incrementalTicker.ImpliedBid, - } - return nil + }) } -func (e *Exchange) processInstrumentTicker(respRaw []byte, channels []string) error { +func (e *Exchange) processInstrumentTicker(ctx context.Context, respRaw []byte, channels []string) error { if len(channels) != 3 { return fmt.Errorf("%w, expected format 'ticker.{instrument_name}.{interval}', but found %s", common.ErrMalformedData, strings.Join(channels, ".")) } - return e.processTicker(respRaw, channels) + return e.processTicker(ctx, respRaw, channels) } -func (e *Exchange) processTicker(respRaw []byte, channels []string) error { +func (e *Exchange) processTicker(ctx context.Context, respRaw []byte, channels []string) error { a, cp, err := getAssetPairByInstrument(channels[1]) if err != nil { return err @@ -606,22 +604,20 @@ func (e *Exchange) processTicker(respRaw []byte, channels []string) error { tickerPrice.Ask = tickerPriceResponse.ImpliedAsk tickerPrice.Bid = tickerPriceResponse.ImpliedBid } - e.Websocket.DataHandler <- tickerPrice - return nil + return e.Websocket.DataHandler.Send(ctx, tickerPrice) } -func (e *Exchange) processData(respRaw []byte, result any) error { +func (e *Exchange) processData(ctx context.Context, respRaw []byte, result any) error { var response wsResponse response.Params.Data = result err := json.Unmarshal(respRaw, &response) if err != nil { return err } - e.Websocket.DataHandler <- result - return nil + return e.Websocket.DataHandler.Send(ctx, result) } -func (e *Exchange) processCandleChart(respRaw []byte, channels []string) error { +func (e *Exchange) processCandleChart(ctx context.Context, respRaw []byte, channels []string) error { if len(channels) != 4 { return fmt.Errorf("%w, expected format 'chart.trades.{instrument_name}.{resolution}', but found %s", common.ErrInvalidResponse, strings.Join(channels, ".")) } @@ -636,7 +632,7 @@ func (e *Exchange) processCandleChart(respRaw []byte, channels []string) error { if err != nil { return err } - e.Websocket.DataHandler <- websocket.KlineData{ + return e.Websocket.DataHandler.Send(ctx, websocket.KlineData{ Timestamp: candleData.Tick.Time(), Pair: cp, AssetType: a, @@ -646,8 +642,7 @@ func (e *Exchange) processCandleChart(respRaw []byte, channels []string) error { LowPrice: candleData.Low, ClosePrice: candleData.Close, Volume: candleData.Volume, - } - return nil + }) } func (e *Exchange) processOrderbook(respRaw []byte, channels []string) error { @@ -869,24 +864,6 @@ func (e *Exchange) handleSubscription(ctx context.Context, method string, subs s return err } -func getValidatedCurrencyCode(pair currency.Pair) string { - currencyCode := pair.Base.Upper().String() - switch currencyCode { - case currencyBTC, currencyETH, - currencySOL, currencyUSDT, - currencyUSDC, currencyEURR: - return currencyCode - default: - switch { - case strings.Contains(pair.String(), currencyUSDC): - return currencyUSDC - case strings.Contains(pair.String(), currencyUSDT): - return currencyUSDT - } - return "any" - } -} - func channelName(s *subscription.Subscription) string { if name, ok := subscriptionNames[s.Channel]; ok { return name diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 4bca1263193..fd6d73238f0 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -94,7 +94,7 @@ func (b *Base) SetClientProxyAddress(addr string) error { } if b.Websocket != nil { - err = b.Websocket.SetProxyAddress(addr) + err = b.Websocket.SetProxyAddress(context.TODO(), addr) if err != nil { return err } @@ -1071,7 +1071,7 @@ func (b *Base) FlushWebsocketChannels() error { if b.Websocket == nil { return nil } - return b.Websocket.FlushChannels() + return b.Websocket.FlushChannels(context.TODO()) } // SubscribeToWebsocketChannels appends to ChannelsToSubscribe @@ -1080,7 +1080,7 @@ func (b *Base) SubscribeToWebsocketChannels(channels subscription.List) error { if b.Websocket == nil { return common.ErrFunctionNotSupported } - return b.Websocket.SubscribeToChannels(b.Websocket.Conn, channels) + return b.Websocket.SubscribeToChannels(context.TODO(), b.Websocket.Conn, channels) } // UnsubscribeToWebsocketChannels removes from ChannelsToSubscribe @@ -1089,7 +1089,7 @@ func (b *Base) UnsubscribeToWebsocketChannels(channels subscription.List) error if b.Websocket == nil { return common.ErrFunctionNotSupported } - return b.Websocket.UnsubscribeChannels(b.Websocket.Conn, channels) + return b.Websocket.UnsubscribeChannels(context.TODO(), b.Websocket.Conn, channels) } // GetSubscriptions returns a copied list of subscriptions diff --git a/exchanges/fill/fill.go b/exchanges/fill/fill.go index 04755fe466e..ecf93277b20 100644 --- a/exchanges/fill/fill.go +++ b/exchanges/fill/fill.go @@ -1,12 +1,17 @@ package fill -import "errors" +import ( + "context" + "errors" + + "github.com/thrasher-corp/gocryptotrader/exchange/stream" +) // ErrFeedDisabled is an error that indicates the fill feed is disabled var ErrFeedDisabled = errors.New("fill feed disabled") // Setup sets up the fill processor -func (f *Fills) Setup(fillsFeedEnabled bool, c chan any) { +func (f *Fills) Setup(fillsFeedEnabled bool, c *stream.Relay) { f.dataHandler = c f.fillsFeedEnabled = fillsFeedEnabled } @@ -14,6 +19,7 @@ func (f *Fills) Setup(fillsFeedEnabled bool, c chan any) { // Update disseminates fill data through the data channel if so // configured func (f *Fills) Update(data ...Data) error { + ctx := context.TODO() if len(data) == 0 { // nothing to do return nil @@ -23,7 +29,5 @@ func (f *Fills) Update(data ...Data) error { return ErrFeedDisabled } - f.dataHandler <- data - - return nil + return f.dataHandler.Send(ctx, data) } diff --git a/exchanges/fill/fill_test.go b/exchanges/fill/fill_test.go index 87250501d1d..0d5dfbe3bbc 100644 --- a/exchanges/fill/fill_test.go +++ b/exchanges/fill/fill_test.go @@ -5,13 +5,13 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" ) // TestSetup tests the setup function of the Fills struct func TestSetup(t *testing.T) { fill := &Fills{} - channel := make(chan any) - fill.Setup(true, channel) + fill.Setup(true, stream.NewRelay(1)) if fill.dataHandler == nil { t.Error("expected dataHandler to be set") @@ -24,15 +24,14 @@ func TestSetup(t *testing.T) { // TestUpdateDisabledFeed tests the Update function when fillsFeedEnabled is false func TestUpdateDisabledFeed(t *testing.T) { - channel := make(chan any, 1) - fill := Fills{dataHandler: channel, fillsFeedEnabled: false} + fill := Fills{dataHandler: stream.NewRelay(1), fillsFeedEnabled: false} // Send a test data to the Update function testData := Data{Timestamp: time.Now(), Price: 15.2, Amount: 3.2} assert.ErrorIs(t, fill.Update(testData), ErrFeedDisabled) select { - case <-channel: + case <-fill.dataHandler.C: t.Errorf("Expected no data on channel, got data") default: // nothing to do @@ -41,16 +40,15 @@ func TestUpdateDisabledFeed(t *testing.T) { // TestUpdate tests the Update function of the Fills struct. func TestUpdate(t *testing.T) { - channel := make(chan any, 1) - fill := &Fills{dataHandler: channel, fillsFeedEnabled: true} + fill := &Fills{dataHandler: stream.NewRelay(1), fillsFeedEnabled: true} receivedData := Data{Timestamp: time.Now(), Price: 15.2, Amount: 3.2} if err := fill.Update(receivedData); err != nil { t.Errorf("Update returned error %v", err) } select { - case data := <-channel: - dataSlice, ok := data.([]Data) + case data := <-fill.dataHandler.C: + dataSlice, ok := data.Data.([]Data) if !ok { t.Errorf("expected []Data, got %T", data) } @@ -65,14 +63,13 @@ func TestUpdate(t *testing.T) { // TestUpdateNoData tests the Update function with no Data objects func TestUpdateNoData(t *testing.T) { - channel := make(chan any, 1) - fill := &Fills{dataHandler: channel, fillsFeedEnabled: true} + fill := &Fills{dataHandler: stream.NewRelay(1), fillsFeedEnabled: true} if err := fill.Update(); err != nil { t.Errorf("Update returned error %v", err) } select { - case <-channel: + case <-fill.dataHandler.C: t.Errorf("Expected no data on channel, got data") default: // pass, nothing to do @@ -81,8 +78,7 @@ func TestUpdateNoData(t *testing.T) { // TestUpdateMultipleData tests the Update function with multiple Data objects func TestUpdateMultipleData(t *testing.T) { - channel := make(chan any, 2) - fill := &Fills{dataHandler: channel, fillsFeedEnabled: true} + fill := &Fills{dataHandler: stream.NewRelay(2), fillsFeedEnabled: true} receivedData := Data{Timestamp: time.Now(), Price: 15.2, Amount: 3.2} receivedData2 := Data{Timestamp: time.Now(), Price: 18.2, Amount: 9.0} if err := fill.Update(receivedData, receivedData2); err != nil { @@ -90,8 +86,8 @@ func TestUpdateMultipleData(t *testing.T) { } select { - case data := <-channel: - dataSlice, ok := data.([]Data) + case data := <-fill.dataHandler.C: + dataSlice, ok := data.Data.([]Data) if !ok { t.Errorf("expected []Data, got %T", data) } diff --git a/exchanges/fill/fill_types.go b/exchanges/fill/fill_types.go index efc2505d942..69d4842f635 100644 --- a/exchanges/fill/fill_types.go +++ b/exchanges/fill/fill_types.go @@ -4,13 +4,14 @@ import ( "time" "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" ) // Fills is used to hold data and methods related to fill dissemination type Fills struct { - dataHandler chan any + dataHandler *stream.Relay fillsFeedEnabled bool } diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 435ecfc282d..6fabce48f22 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -2071,10 +2071,10 @@ func TestFuturesDataHandler(t *testing.T) { } return e.WsHandleFuturesData(ctx, nil, m, asset.CoinMarginedFutures) }) - close(e.Websocket.DataHandler) - assert.Len(t, e.Websocket.DataHandler, 14, "Should see the correct number of messages") - for resp := range e.Websocket.DataHandler { - if err, isErr := resp.(error); isErr { + e.Websocket.DataHandler.Close() + assert.Len(t, e.Websocket.DataHandler.C, 14, "Should see the correct number of messages") + for resp := range e.Websocket.DataHandler.C { + if err, isErr := resp.Data.(error); isErr { assert.NoError(t, err, "Should not get any errors down the data handler") } } diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 06bc719d958..4d8cc4d9073 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -182,11 +182,11 @@ func (e *Exchange) WsHandleSpotData(ctx context.Context, conn websocket.Connecti switch push.Channel { // TODO: Convert function params below to only use push.Result case spotTickerChannel: - return e.processTicker(push.Result, push.Time) + return e.processTicker(ctx, push.Result, push.Time) case spotTradesChannel: return e.processTrades(push.Result) case spotCandlesticksChannel: - return e.processCandlestick(push.Result) + return e.processCandlestick(ctx, push.Result) case spotOrderbookTickerChannel: return e.processOrderbookTicker(push.Result, push.Time) case spotOrderbookUpdateChannel: @@ -194,9 +194,9 @@ func (e *Exchange) WsHandleSpotData(ctx context.Context, conn websocket.Connecti case spotOrderbookChannel: return e.processOrderbookSnapshot(push.Result, push.Time) case spotOrderbookV2: - return e.processOrderbookUpdateWithSnapshot(conn, push.Result, push.Time, asset.Spot) + return e.processOrderbookUpdateWithSnapshot(ctx, conn, push.Result, push.Time, asset.Spot) case spotOrdersChannel: - return e.processSpotOrders(respRaw) + return e.processSpotOrders(ctx, respRaw) case spotUserTradesChannel: return e.processUserPersonalTrades(respRaw) case spotBalancesChannel: @@ -204,17 +204,16 @@ func (e *Exchange) WsHandleSpotData(ctx context.Context, conn websocket.Connecti case marginBalancesChannel: return e.processMarginBalances(ctx, respRaw) case spotFundingBalanceChannel: - return e.processFundingBalances(respRaw) + return e.processFundingBalances(ctx, respRaw) case crossMarginBalanceChannel: return e.processCrossMarginBalance(ctx, respRaw) case crossMarginLoanChannel: - return e.processCrossMarginLoans(respRaw) + return e.processCrossMarginLoans(ctx, respRaw) case spotPongChannel: default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), - } - return errors.New(websocket.UnhandledMessage) + }) } return nil } @@ -258,7 +257,7 @@ func parseWSHeader(msg []byte) (r *WSResponse, errs error) { return r, errs } -func (e *Exchange) processTicker(incoming []byte, pushTime time.Time) error { +func (e *Exchange) processTicker(ctx context.Context, incoming []byte, pushTime time.Time) error { var data WsTicker if err := json.Unmarshal(incoming, &data); err != nil { return err @@ -281,8 +280,7 @@ func (e *Exchange) processTicker(incoming []byte, pushTime time.Time) error { }) } } - e.Websocket.DataHandler <- out - return nil + return e.Websocket.DataHandler.Send(ctx, out) } func (e *Exchange) processTrades(incoming []byte) error { @@ -321,7 +319,7 @@ func (e *Exchange) processTrades(incoming []byte) error { return nil } -func (e *Exchange) processCandlestick(incoming []byte) error { +func (e *Exchange) processCandlestick(ctx context.Context, incoming []byte) error { var data WsCandlesticks if err := json.Unmarshal(incoming, &data); err != nil { return err @@ -352,8 +350,7 @@ func (e *Exchange) processCandlestick(incoming []byte) error { }) } } - e.Websocket.DataHandler <- out - return nil + return e.Websocket.DataHandler.Send(ctx, out) } func (e *Exchange) processOrderbookTicker(incoming []byte, lastPushed time.Time) error { @@ -413,7 +410,7 @@ func (e *Exchange) processOrderbookSnapshot(incoming []byte, lastPushed time.Tim return nil } -func (e *Exchange) processOrderbookUpdateWithSnapshot(conn websocket.Connection, incoming []byte, lastPushed time.Time, a asset.Item) error { +func (e *Exchange) processOrderbookUpdateWithSnapshot(ctx context.Context, conn websocket.Connection, incoming []byte, lastPushed time.Time, a asset.Item) error { var data WsOrderbookUpdateWithSnapshot if err := json.Unmarshal(incoming, &data); err != nil { return err @@ -452,7 +449,7 @@ func (e *Exchange) processOrderbookUpdateWithSnapshot(conn websocket.Connection, lastUpdateID, err := e.Websocket.Orderbook.LastUpdateID(pair, a) if err != nil || lastUpdateID+1 != data.FirstUpdateID { - return common.AppendError(err, e.wsOBResubMgr.Resubscribe(e, conn, data.Channel, pair, a)) + return common.AppendError(err, e.wsOBResubMgr.Resubscribe(ctx, e, conn, data.Channel, pair, a)) } return e.Websocket.Orderbook.Update(&orderbook.Update{ Pair: pair, @@ -466,7 +463,7 @@ func (e *Exchange) processOrderbookUpdateWithSnapshot(conn websocket.Connection, }) } -func (e *Exchange) processSpotOrders(data []byte) error { +func (e *Exchange) processSpotOrders(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -506,8 +503,7 @@ func (e *Exchange) processSpotOrders(data []byte) error { LastUpdated: resp.Result[x].UpdateTime.Time(), } } - e.Websocket.DataHandler <- details - return nil + return e.Websocket.DataHandler.Send(ctx, details) } func (e *Exchange) processUserPersonalTrades(data []byte) error { @@ -565,8 +561,7 @@ func (e *Exchange) processSpotBalances(ctx context.Context, data []byte) error { if err := e.Accounts.Save(ctx, subAccts, false); err != nil { return err } - e.Websocket.DataHandler <- subAccts - return nil + return e.Websocket.DataHandler.Send(ctx, subAccts) } func (e *Exchange) processMarginBalances(ctx context.Context, data []byte) error { @@ -593,11 +588,10 @@ func (e *Exchange) processMarginBalances(ctx context.Context, data []byte) error if err := e.Accounts.Save(ctx, subAccts, false); err != nil { return err } - e.Websocket.DataHandler <- subAccts - return nil + return e.Websocket.DataHandler.Send(ctx, subAccts) } -func (e *Exchange) processFundingBalances(data []byte) error { +func (e *Exchange) processFundingBalances(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -608,8 +602,7 @@ func (e *Exchange) processFundingBalances(data []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- resp - return nil + return e.Websocket.DataHandler.Send(ctx, resp) } func (e *Exchange) processCrossMarginBalance(ctx context.Context, data []byte) error { @@ -636,11 +629,10 @@ func (e *Exchange) processCrossMarginBalance(ctx context.Context, data []byte) e if err := e.Accounts.Save(ctx, subAccts, false); err != nil { return err } - e.Websocket.DataHandler <- subAccts - return nil + return e.Websocket.DataHandler.Send(ctx, subAccts) } -func (e *Exchange) processCrossMarginLoans(data []byte) error { +func (e *Exchange) processCrossMarginLoans(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -651,8 +643,7 @@ func (e *Exchange) processCrossMarginLoans(data []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- resp - return nil + return e.Websocket.DataHandler.Send(ctx, resp) } // generateSubscriptionsSpot returns configured subscriptions diff --git a/exchanges/gateio/gateio_websocket_futures.go b/exchanges/gateio/gateio_websocket_futures.go index 4511965f1e9..54e1a3978a9 100644 --- a/exchanges/gateio/gateio_websocket_futures.go +++ b/exchanges/gateio/gateio_websocket_futures.go @@ -152,47 +152,45 @@ func (e *Exchange) WsHandleFuturesData(ctx context.Context, conn websocket.Conne switch push.Channel { case futuresTickersChannel: - return e.processFuturesTickers(respRaw, a) + return e.processFuturesTickers(ctx, respRaw, a) case futuresTradesChannel: return e.processFuturesTrades(respRaw, a) case futuresOrderbookChannel: return e.processFuturesOrderbookSnapshot(push.Event, push.Result, a, push.Time) case futuresOrderbookTickerChannel: - return e.processFuturesOrderbookTicker(push.Result) + return e.processFuturesOrderbookTicker(ctx, push.Result) case futuresOrderbookUpdateChannel: return e.processFuturesOrderbookUpdate(ctx, push.Result, a, push.Time) case futuresCandlesticksChannel: - return e.processFuturesCandlesticks(respRaw, a) + return e.processFuturesCandlesticks(ctx, respRaw, a) case futuresOrdersChannel: processed, err := e.processFuturesOrdersPushData(respRaw, a) if err != nil { return err } - e.Websocket.DataHandler <- processed - return nil + return e.Websocket.DataHandler.Send(ctx, processed) case futuresUserTradesChannel: return e.procesFuturesUserTrades(respRaw, a) case futuresLiquidatesChannel: - return e.processFuturesLiquidatesNotification(respRaw) + return e.processFuturesLiquidatesNotification(ctx, respRaw) case futuresAutoDeleveragesChannel: - return e.processFuturesAutoDeleveragesNotification(respRaw) + return e.processFuturesAutoDeleveragesNotification(ctx, respRaw) case futuresAutoPositionCloseChannel: - return e.processPositionCloseData(respRaw) + return e.processPositionCloseData(ctx, respRaw) case futuresBalancesChannel: return e.processBalancePushData(ctx, push.Result, a) case futuresReduceRiskLimitsChannel: - return e.processFuturesReduceRiskLimitNotification(respRaw) + return e.processFuturesReduceRiskLimitNotification(ctx, respRaw) case futuresPositionsChannel: - return e.processFuturesPositionsNotification(respRaw) + return e.processFuturesPositionsNotification(ctx, respRaw) case futuresAutoOrdersChannel: - return e.processFuturesAutoOrderPushData(respRaw) + return e.processFuturesAutoOrderPushData(ctx, respRaw) case "futures.pong": return nil default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), - } - return errors.New(websocket.UnhandledMessage) + }) } } @@ -293,7 +291,7 @@ func (e *Exchange) generateFuturesPayload(ctx context.Context, event string, cha return outbound, nil } -func (e *Exchange) processFuturesTickers(data []byte, assetType asset.Item) error { +func (e *Exchange) processFuturesTickers(ctx context.Context, data []byte, assetType asset.Item) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -318,8 +316,7 @@ func (e *Exchange) processFuturesTickers(data []byte, assetType asset.Item) erro LastUpdated: resp.Time.Time(), } } - e.Websocket.DataHandler <- tickerPriceDatas - return nil + return e.Websocket.DataHandler.Send(ctx, tickerPriceDatas) } func (e *Exchange) processFuturesTrades(data []byte, assetType asset.Item) error { @@ -354,7 +351,7 @@ func (e *Exchange) processFuturesTrades(data []byte, assetType asset.Item) error return e.Websocket.Trade.Update(saveTradeData, trades...) } -func (e *Exchange) processFuturesCandlesticks(data []byte, assetType asset.Item) error { +func (e *Exchange) processFuturesCandlesticks(ctx context.Context, data []byte, assetType asset.Item) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -388,18 +385,16 @@ func (e *Exchange) processFuturesCandlesticks(data []byte, assetType asset.Item) Volume: resp.Result[x].Volume, } } - e.Websocket.DataHandler <- klineDatas - return nil + return e.Websocket.DataHandler.Send(ctx, klineDatas) } -func (e *Exchange) processFuturesOrderbookTicker(incoming []byte) error { +func (e *Exchange) processFuturesOrderbookTicker(ctx context.Context, incoming []byte) error { var data WsFuturesOrderbookTicker err := json.Unmarshal(incoming, &data) if err != nil { return err } - e.Websocket.DataHandler <- data - return nil + return e.Websocket.DataHandler.Send(ctx, data) } func (e *Exchange) processFuturesOrderbookUpdate(ctx context.Context, incoming []byte, a asset.Item, pushTime time.Time) error { @@ -532,11 +527,7 @@ func (e *Exchange) processFuturesOrdersPushData(data []byte, assetType asset.Ite status, err = order.StringToOrderStatus(resp.Result[x].Status) } if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: strconv.FormatInt(resp.Result[x].ID, 10), - Err: err, - } + return nil, err } orderDetails[x] = order.Detail{ @@ -588,7 +579,7 @@ func (e *Exchange) procesFuturesUserTrades(data []byte, assetType asset.Item) er return e.Websocket.Fills.Update(fills...) } -func (e *Exchange) processFuturesLiquidatesNotification(data []byte) error { +func (e *Exchange) processFuturesLiquidatesNotification(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -599,11 +590,10 @@ func (e *Exchange) processFuturesLiquidatesNotification(data []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } -func (e *Exchange) processFuturesAutoDeleveragesNotification(data []byte) error { +func (e *Exchange) processFuturesAutoDeleveragesNotification(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -614,11 +604,10 @@ func (e *Exchange) processFuturesAutoDeleveragesNotification(data []byte) error if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } -func (e *Exchange) processPositionCloseData(data []byte) error { +func (e *Exchange) processPositionCloseData(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -629,8 +618,7 @@ func (e *Exchange) processPositionCloseData(data []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } func (e *Exchange) processBalancePushData(ctx context.Context, data []byte, assetType asset.Item) error { @@ -654,14 +642,13 @@ func (e *Exchange) processBalancePushData(ctx context.Context, data []byte, asse }) subAccts = subAccts.Merge(a) } - err := e.Accounts.Save(ctx, subAccts, false) - if err == nil { - e.Websocket.DataHandler <- subAccts + if err := e.Accounts.Save(ctx, subAccts, false); err != nil { + return err } - return err + return e.Websocket.DataHandler.Send(ctx, subAccts) } -func (e *Exchange) processFuturesReduceRiskLimitNotification(data []byte) error { +func (e *Exchange) processFuturesReduceRiskLimitNotification(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -672,11 +659,10 @@ func (e *Exchange) processFuturesReduceRiskLimitNotification(data []byte) error if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } -func (e *Exchange) processFuturesPositionsNotification(data []byte) error { +func (e *Exchange) processFuturesPositionsNotification(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -687,11 +673,10 @@ func (e *Exchange) processFuturesPositionsNotification(data []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } -func (e *Exchange) processFuturesAutoOrderPushData(data []byte) error { +func (e *Exchange) processFuturesAutoOrderPushData(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -702,6 +687,5 @@ func (e *Exchange) processFuturesAutoOrderPushData(data []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } diff --git a/exchanges/gateio/gateio_websocket_option.go b/exchanges/gateio/gateio_websocket_option.go index 735ed88cfd6..7c62e33271f 100644 --- a/exchanges/gateio/gateio_websocket_option.go +++ b/exchanges/gateio/gateio_websocket_option.go @@ -296,60 +296,59 @@ func (e *Exchange) WsHandleOptionsData(ctx context.Context, conn websocket.Conne switch push.Channel { case optionsContractTickersChannel: - return e.processOptionsContractTickers(push.Result) + return e.processOptionsContractTickers(ctx, push.Result) case optionsUnderlyingTickersChannel: - return e.processOptionsUnderlyingTicker(push.Result) + return e.processOptionsUnderlyingTicker(ctx, push.Result) case optionsTradesChannel, optionsUnderlyingTradesChannel: return e.processOptionsTradesPushData(respRaw) case optionsUnderlyingPriceChannel: - return e.processOptionsUnderlyingPricePushData(push.Result) + return e.processOptionsUnderlyingPricePushData(ctx, push.Result) case optionsMarkPriceChannel: - return e.processOptionsMarkPrice(push.Result) + return e.processOptionsMarkPrice(ctx, push.Result) case optionsSettlementChannel: - return e.processOptionsSettlementPushData(push.Result) + return e.processOptionsSettlementPushData(ctx, push.Result) case optionsContractsChannel: - return e.processOptionsContractPushData(push.Result) + return e.processOptionsContractPushData(ctx, push.Result) case optionsContractCandlesticksChannel, optionsUnderlyingCandlesticksChannel: - return e.processOptionsCandlestickPushData(respRaw) + return e.processOptionsCandlestickPushData(ctx, respRaw) case optionsOrderbookChannel: return e.processOptionsOrderbookSnapshotPushData(push.Event, push.Result, push.Time) case optionsOrderbookTickerChannel: - return e.processOrderbookTickerPushData(respRaw) + return e.processOrderbookTickerPushData(ctx, respRaw) case optionsOrderbookUpdateChannel: return e.processOptionsOrderbookUpdate(ctx, push.Result, asset.Options, push.Time) case optionsOrdersChannel: - return e.processOptionsOrderPushData(respRaw) + return e.processOptionsOrderPushData(ctx, respRaw) case optionsUserTradesChannel: return e.processOptionsUserTradesPushData(respRaw) case optionsLiquidatesChannel: - return e.processOptionsLiquidatesPushData(respRaw) + return e.processOptionsLiquidatesPushData(ctx, respRaw) case optionsUserSettlementChannel: - return e.processOptionsUsersPersonalSettlementsPushData(respRaw) + return e.processOptionsUsersPersonalSettlementsPushData(ctx, respRaw) case optionsPositionCloseChannel: - return e.processPositionCloseData(respRaw) + return e.processPositionCloseData(ctx, respRaw) case optionsBalancesChannel: return e.processBalancePushData(ctx, push.Result, asset.Options) case optionsPositionsChannel: - return e.processOptionsPositionPushData(respRaw) + return e.processOptionsPositionPushData(ctx, respRaw) case "options.pong": return nil default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), - } - return errors.New(websocket.UnhandledMessage) + }) } } -func (e *Exchange) processOptionsContractTickers(incoming []byte) error { +func (e *Exchange) processOptionsContractTickers(ctx context.Context, incoming []byte) error { var data OptionsTicker err := json.Unmarshal(incoming, &data) if err != nil { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ Pair: data.Name, Last: data.LastPrice.Float64(), Bid: data.Bid1Price.Float64(), @@ -358,18 +357,16 @@ func (e *Exchange) processOptionsContractTickers(incoming []byte) error { BidSize: data.Bid1Size, ExchangeName: e.Name, AssetType: asset.Options, - } - return nil + }) } -func (e *Exchange) processOptionsUnderlyingTicker(incoming []byte) error { +func (e *Exchange) processOptionsUnderlyingTicker(ctx context.Context, incoming []byte) error { var data WsOptionUnderlyingTicker err := json.Unmarshal(incoming, &data) if err != nil { return err } - e.Websocket.DataHandler <- &data - return nil + return e.Websocket.DataHandler.Send(ctx, &data) } func (e *Exchange) processOptionsTradesPushData(data []byte) error { @@ -403,47 +400,43 @@ func (e *Exchange) processOptionsTradesPushData(data []byte) error { return e.Websocket.Trade.Update(saveTradeData, trades...) } -func (e *Exchange) processOptionsUnderlyingPricePushData(incoming []byte) error { +func (e *Exchange) processOptionsUnderlyingPricePushData(ctx context.Context, incoming []byte) error { var data WsOptionsUnderlyingPrice err := json.Unmarshal(incoming, &data) if err != nil { return err } - e.Websocket.DataHandler <- &data - return nil + return e.Websocket.DataHandler.Send(ctx, &data) } -func (e *Exchange) processOptionsMarkPrice(incoming []byte) error { +func (e *Exchange) processOptionsMarkPrice(ctx context.Context, incoming []byte) error { var data WsOptionsMarkPrice err := json.Unmarshal(incoming, &data) if err != nil { return err } - e.Websocket.DataHandler <- &data - return nil + return e.Websocket.DataHandler.Send(ctx, &data) } -func (e *Exchange) processOptionsSettlementPushData(incoming []byte) error { +func (e *Exchange) processOptionsSettlementPushData(ctx context.Context, incoming []byte) error { var data WsOptionsSettlement err := json.Unmarshal(incoming, &data) if err != nil { return err } - e.Websocket.DataHandler <- &data - return nil + return e.Websocket.DataHandler.Send(ctx, &data) } -func (e *Exchange) processOptionsContractPushData(incoming []byte) error { +func (e *Exchange) processOptionsContractPushData(ctx context.Context, incoming []byte) error { var data WsOptionsContract err := json.Unmarshal(incoming, &data) if err != nil { return err } - e.Websocket.DataHandler <- &data - return nil + return e.Websocket.DataHandler.Send(ctx, &data) } -func (e *Exchange) processOptionsCandlestickPushData(data []byte) error { +func (e *Exchange) processOptionsCandlestickPushData(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -477,18 +470,16 @@ func (e *Exchange) processOptionsCandlestickPushData(data []byte) error { Volume: resp.Result[x].Amount.Float64(), } } - e.Websocket.DataHandler <- klineDatas - return nil + return e.Websocket.DataHandler.Send(ctx, klineDatas) } -func (e *Exchange) processOrderbookTickerPushData(incoming []byte) error { +func (e *Exchange) processOrderbookTickerPushData(ctx context.Context, incoming []byte) error { var data WsOptionsOrderbookTicker err := json.Unmarshal(incoming, &data) if err != nil { return err } - e.Websocket.DataHandler <- &data - return nil + return e.Websocket.DataHandler.Send(ctx, &data) } func (e *Exchange) processOptionsOrderbookUpdate(ctx context.Context, incoming []byte, a asset.Item, pushTime time.Time) error { @@ -594,7 +585,7 @@ func (e *Exchange) processOptionsOrderbookSnapshotPushData(event string, incomin return nil } -func (e *Exchange) processOptionsOrderPushData(data []byte) error { +func (e *Exchange) processOptionsOrderPushData(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -629,8 +620,7 @@ func (e *Exchange) processOptionsOrderPushData(data []byte) error { AccountID: resp.Result[x].User, } } - e.Websocket.DataHandler <- orderDetails - return nil + return e.Websocket.DataHandler.Send(ctx, orderDetails) } func (e *Exchange) processOptionsUserTradesPushData(data []byte) error { @@ -662,7 +652,7 @@ func (e *Exchange) processOptionsUserTradesPushData(data []byte) error { return e.Websocket.Fills.Update(fills...) } -func (e *Exchange) processOptionsLiquidatesPushData(data []byte) error { +func (e *Exchange) processOptionsLiquidatesPushData(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -673,11 +663,10 @@ func (e *Exchange) processOptionsLiquidatesPushData(data []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } -func (e *Exchange) processOptionsUsersPersonalSettlementsPushData(data []byte) error { +func (e *Exchange) processOptionsUsersPersonalSettlementsPushData(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -688,11 +677,10 @@ func (e *Exchange) processOptionsUsersPersonalSettlementsPushData(data []byte) e if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } -func (e *Exchange) processOptionsPositionPushData(data []byte) error { +func (e *Exchange) processOptionsPositionPushData(ctx context.Context, data []byte) error { resp := struct { Time types.Time `json:"time"` Channel string `json:"channel"` @@ -703,6 +691,5 @@ func (e *Exchange) processOptionsPositionPushData(data []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index a1f02b2c521..ebcb6f7facc 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -214,6 +214,6 @@ func newExchangeWithWebsocket(t *testing.T, a asset.Item) *Exchange { } } - require.NoError(t, e.Websocket.Connect()) + require.NoError(t, e.Websocket.Connect(t.Context())) return e } diff --git a/exchanges/gateio/gateio_websocket_test.go b/exchanges/gateio/gateio_websocket_test.go index 6e8cb29614b..cc9d5c938d9 100644 --- a/exchanges/gateio/gateio_websocket_test.go +++ b/exchanges/gateio/gateio_websocket_test.go @@ -197,9 +197,9 @@ func TestProcessBalancePushData(t *testing.T) { //nolint:tparallel // Sequential func checkAccountChange(ctx context.Context, t *testing.T, exch *Exchange, tc *websocketBalancesTest) { t.Helper() - require.Len(t, exch.Websocket.DataHandler, 1) - payload := <-exch.Websocket.DataHandler - received, ok := payload.(accounts.SubAccounts) + require.Len(t, exch.Websocket.DataHandler.C, 1) + payload := <-exch.Websocket.DataHandler.C + received, ok := payload.Data.(accounts.SubAccounts) require.Truef(t, ok, "Expected account changes, got %T", payload) require.Lenf(t, received, len(tc.expected), "Expected %d changes, got %d", len(tc.expected), len(received)) @@ -299,7 +299,7 @@ func TestProcessOrderbookUpdateWithSnapshot(t *testing.T) { }, } { // Sequential tests, do not use t.Parallel(); Some timestamps are deliberately identical from trading activity - err := e.processOrderbookUpdateWithSnapshot(conn, tc.payload, time.Now(), asset.Spot) + err := e.processOrderbookUpdateWithSnapshot(t.Context(), conn, tc.payload, time.Now(), asset.Spot) if tc.err != nil { require.ErrorIs(t, err, tc.err) continue diff --git a/exchanges/gateio/ws_ob_resub_manager.go b/exchanges/gateio/ws_ob_resub_manager.go index ee9947e73d1..d8a8adc3a5f 100644 --- a/exchanges/gateio/ws_ob_resub_manager.go +++ b/exchanges/gateio/ws_ob_resub_manager.go @@ -1,6 +1,7 @@ package gateio import ( + "context" "fmt" "sync" @@ -29,7 +30,7 @@ func (m *wsOBResubManager) IsResubscribing(pair currency.Pair, a asset.Item) boo } // Resubscribe marks a subscription as resubscribing and starts the unsubscribe/resubscribe process -func (m *wsOBResubManager) Resubscribe(e *Exchange, conn websocket.Connection, qualifiedChannel string, pair currency.Pair, a asset.Item) error { +func (m *wsOBResubManager) Resubscribe(ctx context.Context, e *Exchange, conn websocket.Connection, qualifiedChannel string, pair currency.Pair, a asset.Item) error { if err := e.Websocket.Orderbook.InvalidateOrderbook(pair, a); err != nil { return err } @@ -45,7 +46,7 @@ func (m *wsOBResubManager) Resubscribe(e *Exchange, conn websocket.Connection, q m.lookup[key.PairAsset{Base: pair.Base.Item, Quote: pair.Quote.Item, Asset: a}] = true go func() { // Has to be called in routine to not impede websocket throughput - if err := e.Websocket.ResubscribeToChannel(conn, sub); err != nil { + if err := e.Websocket.ResubscribeToChannel(ctx, conn, sub); err != nil { m.CompletedResubscribe(pair, a) // Ensure we clear the map entry on failure too log.Errorf(log.ExchangeSys, "Failed to resubscribe to channel %q: %v", qualifiedChannel, err) } diff --git a/exchanges/gateio/ws_ob_resub_manager_test.go b/exchanges/gateio/ws_ob_resub_manager_test.go index ecc093d399b..e0bf86953b3 100644 --- a/exchanges/gateio/ws_ob_resub_manager_test.go +++ b/exchanges/gateio/ws_ob_resub_manager_test.go @@ -42,7 +42,7 @@ func TestResubscribe(t *testing.T) { require.NoError(t, testexch.Setup(e)) e.Name = "Resubscribe" - err := m.Resubscribe(e, conn, "notfound", currency.NewBTCUSDT(), asset.Spot) + err := m.Resubscribe(t.Context(), e, conn, "notfound", currency.NewBTCUSDT(), asset.Spot) require.ErrorIs(t, err, orderbook.ErrDepthNotFound) require.False(t, m.IsResubscribing(currency.NewBTCUSDT(), asset.Spot)) @@ -55,7 +55,7 @@ func TestResubscribe(t *testing.T) { LastUpdated: time.Now(), }) require.NoError(t, err) - err = m.Resubscribe(e, conn, "notfound", currency.NewBTCUSDT(), asset.Spot) + err = m.Resubscribe(t.Context(), e, conn, "notfound", currency.NewBTCUSDT(), asset.Spot) require.ErrorIs(t, err, subscription.ErrNotFound) require.False(t, m.IsResubscribing(currency.NewBTCUSDT(), asset.Spot)) @@ -78,7 +78,7 @@ func TestResubscribe(t *testing.T) { LastUpdated: time.Now(), }) require.NoError(t, err) - err = m.Resubscribe(e, conn, "ob.BTC_USDT.50", currency.NewBTCUSDT(), asset.Spot) + err = m.Resubscribe(t.Context(), e, conn, "ob.BTC_USDT.50", currency.NewBTCUSDT(), asset.Spot) require.NoError(t, err) assert.True(t, m.IsResubscribing(currency.NewBTCUSDT(), asset.Spot)) } diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index 3dd0fe3cb81..72c461dda0d 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -560,15 +560,14 @@ func TestWsAuth(t *testing.T) { t.Skip(websocket.ErrWebsocketNotEnabled.Error()) } var dialer gws.Dialer - go e.wsReadData() err = e.WsAuth(t.Context(), &dialer) if err != nil { t.Error(err) } timer := time.NewTimer(sharedtestvalues.WebsocketResponseDefaultTimeout) select { - case resp := <-e.Websocket.DataHandler: - subAck, ok := resp.(WsSubscriptionAcknowledgementResponse) + case resp := <-e.Websocket.DataHandler.C: + subAck, ok := resp.Data.(WsSubscriptionAcknowledgementResponse) if !ok { t.Error("unable to type assert WsSubscriptionAcknowledgementResponse") } @@ -587,7 +586,7 @@ func TestWsMissingRole(t *testing.T) { "reason":"MissingRole", "message":"To access this endpoint, you need to log in to the website and go to the settings page to assign one of these roles [FundManager] to API key wujB3szN54gtJ4QDhqRJ which currently has roles [Trader]" }`) - if err := e.wsHandleData(pressXToJSON); err == nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err == nil { t.Error("Expected error") } } @@ -611,7 +610,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "original_amount" : "14.0296", "price" : "1059.54" } ]`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -633,7 +632,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "price": "3592.00", "socket_sequence": 13 }]`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -654,7 +653,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "total_spend": "200.00", "socket_sequence": 29 }]`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -675,7 +674,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "original_amount": "25", "socket_sequence": 26 }]`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -697,7 +696,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "original_amount" : "500", "socket_sequence" : 32307 } ]`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -719,7 +718,7 @@ func TestWsSubAck(t *testing.T) { "closed" ] }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } @@ -732,7 +731,7 @@ func TestWsHeartbeat(t *testing.T) { "trace_id": "b8biknoqppr32kc7gfgg", "socket_sequence": 37 }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } @@ -753,7 +752,7 @@ func TestWsUnsubscribe(t *testing.T) { ]} ] }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -776,7 +775,7 @@ func TestWsTradeData(t *testing.T) { } ] }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } @@ -808,7 +807,7 @@ func TestWsAuctionData(t *testing.T) { ], "type": "update" }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } @@ -829,7 +828,7 @@ func TestWsBlockTrade(t *testing.T) { } ] }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } @@ -844,7 +843,7 @@ func TestWSTrade(t *testing.T) { "quantity": "0.09110000", "side": "buy" }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } @@ -872,7 +871,7 @@ func TestWsCandles(t *testing.T) { ] ] }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } @@ -894,7 +893,7 @@ func TestWsAuctions(t *testing.T) { ], "type": "update" }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } @@ -918,7 +917,7 @@ func TestWsAuctions(t *testing.T) { } ] }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } @@ -949,7 +948,7 @@ func TestWsAuctions(t *testing.T) { } ] }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } @@ -978,7 +977,7 @@ func TestWsMarketData(t *testing.T) { } ] } `) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -1006,7 +1005,7 @@ func TestWsMarketData(t *testing.T) { } ] } `) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -1028,7 +1027,7 @@ func TestWsMarketData(t *testing.T) { } ] } `) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -1066,7 +1065,7 @@ func TestWsError(t *testing.T) { } for x := range tt { - err := e.wsHandleData(tt[x].Data) + err := e.wsHandleData(t.Context(), tt[x].Data) if tt[x].ErrorExpected && err != nil && !strings.Contains(err.Error(), tt[x].ErrorShouldContain) { t.Errorf("expected error to contain: %s, got: %s", tt[x].ErrorShouldContain, err.Error(), @@ -1126,7 +1125,7 @@ func TestWsLevel2Update(t *testing.T) { } ] }`) - if err := e.wsHandleData(pressXToJSON); err != nil { + if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { t.Error(err) } } diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index f2f5a6fd772..f94a99999b5 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -54,9 +54,6 @@ var subscriptionNames = map[string]string{ subscription.OrderbookChannel: marketDataLevel2, } -// Instantiates a communications channel between websocket connections -var comms = make(chan websocket.Response) - // WsConnect initiates a websocket connection func (e *Exchange) WsConnect() error { ctx := context.TODO() @@ -70,9 +67,8 @@ func (e *Exchange) WsConnect() error { return err } - e.Websocket.Wg.Add(2) - go e.wsReadData() - go e.wsFunnelConnectionData(e.Websocket.Conn) + e.Websocket.Wg.Add(1) + go e.wsReadData(ctx, e.Websocket.Conn) if e.Websocket.CanUseAuthenticatedEndpoints() { err := e.WsAuth(ctx, &dialer) @@ -173,54 +169,26 @@ func (e *Exchange) WsAuth(ctx context.Context, dialer *gws.Dialer) error { return fmt.Errorf("%v Websocket connection %v error. Error %v", e.Name, endpoint, err) } e.Websocket.Wg.Add(1) - go e.wsFunnelConnectionData(e.Websocket.AuthConn) + go e.wsReadData(ctx, e.Websocket.AuthConn) return nil } -// wsFunnelConnectionData receives data from multiple connections and passes it to wsReadData -func (e *Exchange) wsFunnelConnectionData(ws websocket.Connection) { +func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) { defer e.Websocket.Wg.Done() for { resp := ws.ReadMessage() if resp.Raw == nil { return } - comms <- websocket.Response{Raw: resp.Raw} - } -} - -// wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData() { - defer e.Websocket.Wg.Done() - for { - select { - case <-e.Websocket.ShutdownC: - select { - case resp := <-comms: - err := e.wsHandleData(resp.Raw) - if err != nil { - select { - case e.Websocket.DataHandler <- err: - default: - log.Errorf(log.WebsocketMgr, - "%s websocket handle data error: %v", - e.Name, - err) - } - } - default: - } - return - case resp := <-comms: - err := e.wsHandleData(resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, ws.GetURL(), errSend, err) } } } } -func (e *Exchange) wsHandleData(respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { // only order details are sent in arrays if strings.HasPrefix(string(respRaw), "[") { var result []WsOrderResponse @@ -232,29 +200,17 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { for i := range result { oSide, err := order.StringToOrderSide(result[i].Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: result[i].OrderID, - Err: err, - } + return err } var oType order.Type oType, err = stringToOrderType(result[i].OrderType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: result[i].OrderID, - Err: err, - } + return err } var oStatus order.Status oStatus, err = stringToOrderStatus(result[i].Type) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: result[i].OrderID, - Err: err, - } + return err } enabledPairs, err := e.GetAvailablePairs(asset.Spot) @@ -272,7 +228,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return err } - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ HiddenOrder: result[i].IsHidden, Price: result[i].Price, Amount: result[i].OriginalAmount, @@ -286,6 +242,8 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { AssetType: asset.Spot, Date: result[i].TimestampMS.Time(), Pair: pair, + }); err != nil { + return err } } return nil @@ -303,7 +261,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { if err != nil { return err } - return e.wsProcessUpdate(l2MarketData) + return e.wsProcessUpdate(ctx, l2MarketData) case "trade": if !e.IsSaveTradeDataEnabled() { return nil @@ -317,10 +275,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { tSide, err := order.StringToOrderSide(result.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } enabledPairs, err := e.GetEnabledPairs(asset.Spot) @@ -356,14 +311,14 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- result + return e.Websocket.DataHandler.Send(ctx, result) case "initial": var result WsSubscriptionAcknowledgementResponse err := json.Unmarshal(respRaw, &result) if err != nil { return err } - e.Websocket.DataHandler <- result + return e.Websocket.DataHandler.Send(ctx, result) case "heartbeat": return nil case "candles_1m_updates", @@ -401,7 +356,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { if !ok { return errors.New("unable to type assert interval") } - e.Websocket.DataHandler <- websocket.KlineData{ + if err := e.Websocket.DataHandler.Send(ctx, websocket.KlineData{ Timestamp: time.UnixMilli(int64(candle.Changes[i][0])), Pair: pair, AssetType: asset.Spot, @@ -412,11 +367,12 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { LowPrice: candle.Changes[i][3], ClosePrice: candle.Changes[i][4], Volume: candle.Changes[i][5], + }); err != nil { + return err } } default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} - return nil + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } } else if r, ok := result["result"].(string); ok { switch r { @@ -429,8 +385,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { } return fmt.Errorf("%v Unhandled websocket error %s", e.Name, respRaw) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} - return nil + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } } return nil @@ -468,7 +423,7 @@ func stringToOrderType(oType string) (order.Type, error) { } } -func (e *Exchange) wsProcessUpdate(result *wsL2MarketData) error { +func (e *Exchange) wsProcessUpdate(ctx context.Context, result *wsL2MarketData) error { isInitial := len(result.Changes) > 0 && len(result.Trades) > 0 enabledPairs, err := e.GetEnabledPairs(asset.Spot) if err != nil { @@ -538,7 +493,9 @@ func (e *Exchange) wsProcessUpdate(result *wsL2MarketData) error { } if len(result.AuctionEvents) > 0 { - e.Websocket.DataHandler <- result.AuctionEvents + if err := e.Websocket.DataHandler.Send(ctx, result.AuctionEvents); err != nil { + return err + } } if !e.IsSaveTradeDataEnabled() { @@ -549,10 +506,7 @@ func (e *Exchange) wsProcessUpdate(result *wsL2MarketData) error { for x := range result.Trades { tSide, err := order.StringToOrderSide(result.Trades[x].Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - Err: err, - } + return err } trades[x] = trade.Data{ Timestamp: result.Trades[x].Timestamp.Time(), diff --git a/exchanges/hitbtc/hitbtc_test.go b/exchanges/hitbtc/hitbtc_test.go index 5ae57be18ef..410d4e05526 100644 --- a/exchanges/hitbtc/hitbtc_test.go +++ b/exchanges/hitbtc/hitbtc_test.go @@ -437,14 +437,14 @@ func setupWsAuth(t *testing.T) { if err != nil { t.Fatal(err) } - go e.wsReadData() + go e.wsReadData(t.Context()) err = e.wsLogin(t.Context()) if err != nil { t.Fatal(err) } timer := time.NewTimer(time.Second) select { - case loginError := <-e.Websocket.DataHandler: + case loginError := <-e.Websocket.DataHandler.C: t.Fatal(loginError) case <-timer.C: } @@ -546,7 +546,7 @@ func TestWsGetActiveOrdersJSON(t *testing.T) { } ] }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -570,7 +570,7 @@ func TestWsGetCurrenciesJSON(t *testing.T) { }, "id": "c4ce77f5-1c50-435a-b623-4961191ca129" }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -591,7 +591,7 @@ func TestWsGetSymbolsJSON(t *testing.T) { }, "id": "1c847290-b366-412b-b8f5-dc630ed5b147" }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -614,7 +614,7 @@ func TestWsTicker(t *testing.T) { "symbol": "BTCUSD" } }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -658,7 +658,7 @@ func TestWsOrderbook(t *testing.T) { "timestamp": "2018-11-19T05:00:28.193Z" } }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -688,7 +688,7 @@ func TestWsOrderbook(t *testing.T) { "timestamp": "2018-11-19T05:00:28.700Z" } }`) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -719,7 +719,7 @@ func TestWsOrderNotification(t *testing.T) { "tradeFee": "-0.000000005" } }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -746,7 +746,7 @@ func TestWsSubmitOrderJSON(t *testing.T) { }, "id": "99f55c70-1166-49a7-87e9-3b54a00ad893" }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -773,7 +773,7 @@ func TestWsCancelOrderJSON(t *testing.T) { }, "id": "2ce46937-2770-4453-ac99-ee87939bf5bb" }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -801,7 +801,7 @@ func TestWsCancelReplaceJSON(t *testing.T) { }, "id": "91e925d3-3b95-4e29-8ae7-938fd5006709" }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -829,7 +829,7 @@ func TestWsGetTradesRequestResponse(t *testing.T) { ], "id": "4b1f1391-215e-4d12-972c-5cea9d50edf4" }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -859,7 +859,7 @@ func TestWsGetActiveOrdersRequestJSON(t *testing.T) { ], "id": "9e67b440-2eec-445a-be3a-e81f962c8391" }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -896,7 +896,7 @@ func TestWsTrades(t *testing.T) { "symbol": "BTCUSD" } }`) - err := e.wsHandleData(pressXToJSON) + err := e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } @@ -917,7 +917,7 @@ func TestWsTrades(t *testing.T) { "symbol": "BTCUSD" } } `) - err = e.wsHandleData(pressXToJSON) + err = e.wsHandleData(t.Context(), pressXToJSON) if err != nil { t.Error(err) } diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index 949a70914f0..21ff2e65c3c 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -64,7 +64,7 @@ func (e *Exchange) WsConnect() error { } e.Websocket.Wg.Add(1) - go e.wsReadData() + go e.wsReadData(ctx) if e.Websocket.CanUseAuthenticatedEndpoints() { err = e.wsLogin(ctx) @@ -77,7 +77,7 @@ func (e *Exchange) WsConnect() error { } // wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData() { +func (e *Exchange) wsReadData(ctx context.Context) { defer e.Websocket.Wg.Done() for { @@ -86,14 +86,15 @@ func (e *Exchange) wsReadData() { return } - err := e.wsHandleData(resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } -func (e *Exchange) wsGetTableName(respRaw []byte) (string, error) { +func (e *Exchange) wsGetTableName(ctx context.Context, respRaw []byte) (string, error) { var init capture err := json.Unmarshal(respRaw, &init) if err != nil { @@ -132,8 +133,7 @@ func (e *Exchange) wsGetTableName(respRaw []byte) (string, error) { } case []any: if len(resultType) == 0 { - e.Websocket.DataHandler <- fmt.Sprintf("No data returned. ID: %v", init.ID) - return "", nil + return "", fmt.Errorf("no data returned. ID: %v", init.ID) } data, ok := resultType[0].(map[string]any) @@ -146,12 +146,11 @@ func (e *Exchange) wsGetTableName(respRaw []byte) (string, error) { return "trading", nil } } - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} - return "", nil + return "", e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } -func (e *Exchange) wsHandleData(respRaw []byte) error { - name, err := e.wsGetTableName(respRaw) +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { + name, err := e.wsGetTableName(ctx, respRaw) if err != nil { return err } @@ -182,7 +181,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Open: wsTicker.Params.Open, Volume: wsTicker.Params.Volume, @@ -195,7 +194,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { LastUpdated: wsTicker.Params.Timestamp, AssetType: asset.Spot, Pair: p, - } + }) case "snapshotOrderbook": var obSnapshot WsOrderbook err := json.Unmarshal(respRaw, &obSnapshot) @@ -260,7 +259,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return err } for i := range o.Params { - err = e.wsHandleOrderData(&o.Params[i]) + err = e.wsHandleOrderData(ctx, &o.Params[i]) if err != nil { return err } @@ -271,14 +270,14 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- trades + return e.Websocket.DataHandler.Send(ctx, trades) case "report": var o wsReportResponse err := json.Unmarshal(respRaw, &o) if err != nil { return err } - err = e.wsHandleOrderData(&o.OrderData) + err = e.wsHandleOrderData(ctx, &o.OrderData) if err != nil { return err } @@ -289,7 +288,7 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { return err } for i := range o.OrderData { - err = e.wsHandleOrderData(&o.OrderData[i]) + err = e.wsHandleOrderData(ctx, &o.OrderData[i]) if err != nil { return err } @@ -300,13 +299,12 @@ func (e *Exchange) wsHandleData(respRaw []byte) error { if err != nil { return err } - err = e.wsHandleOrderData(&o.OrderData) + err = e.wsHandleOrderData(ctx, &o.OrderData) if err != nil { return err } default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} - return nil + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } return nil } @@ -344,11 +342,8 @@ func (e *Exchange) WsProcessOrderbookSnapshot(ob *WsOrderbook) error { return err } - p, err := currency.NewPairFromFormattedPairs(ob.Params.Symbol, - pairs, - format) + p, err := currency.NewPairFromFormattedPairs(ob.Params.Symbol, pairs, format) if err != nil { - e.Websocket.DataHandler <- err return err } @@ -361,7 +356,7 @@ func (e *Exchange) WsProcessOrderbookSnapshot(ob *WsOrderbook) error { return e.Websocket.Orderbook.LoadSnapshot(&newOrderBook) } -func (e *Exchange) wsHandleOrderData(o *wsOrderData) error { +func (e *Exchange) wsHandleOrderData(ctx context.Context, o *wsOrderData) error { var trades []order.TradeHistory if o.TradeID > 0 { trades = append(trades, order.TradeHistory{ @@ -375,37 +370,21 @@ func (e *Exchange) wsHandleOrderData(o *wsOrderData) error { } oType, err := order.StringToOrderType(o.Type) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: o.ID, - Err: err, - } + return err } o.Status = strings.Replace(o.Status, "canceled", "cancelled", 1) oStatus, err := order.StringToOrderStatus(o.Status) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: o.ID, - Err: err, - } + return err } oSide, err := order.StringToOrderSide(o.Side) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: o.ID, - Err: err, - } + return err } p, err := currency.NewPairFromString(o.Symbol) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: o.ID, - Err: err, - } + return err } var a asset.Item @@ -413,7 +392,7 @@ func (e *Exchange) wsHandleOrderData(o *wsOrderData) error { if err != nil { return err } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: o.Price, Amount: o.Quantity, ExecutedAmount: o.CumQuantity, @@ -428,8 +407,7 @@ func (e *Exchange) wsHandleOrderData(o *wsOrderData) error { LastUpdated: o.UpdatedAt, Pair: p, Trades: trades, - } - return nil + }) } // WsProcessOrderbookUpdate updates a local cache diff --git a/exchanges/huobi/huobi_test.go b/exchanges/huobi/huobi_test.go index 84ec79ecff6..89dd6ef0112 100644 --- a/exchanges/huobi/huobi_test.go +++ b/exchanges/huobi/huobi_test.go @@ -1294,10 +1294,10 @@ func TestWSCandles(t *testing.T) { err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "market.btcusdt.kline.1min", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.CandlesChannel}) require.NoError(t, err, "AddSubscriptions must not error") testexch.FixtureToDataHandler(t, "testdata/wsCandles.json", e.wsHandleData) - close(e.Websocket.DataHandler) - require.Len(t, e.Websocket.DataHandler, 1, "Must see correct number of records") - cAny := <-e.Websocket.DataHandler - c, ok := cAny.(websocket.KlineData) + e.Websocket.DataHandler.Close() + require.Len(t, e.Websocket.DataHandler.C, 1, "Must see correct number of records") + cAny := <-e.Websocket.DataHandler.C + c, ok := cAny.Data.(websocket.KlineData) require.True(t, ok, "Must get the correct type from DataHandler") exp := websocket.KlineData{ Timestamp: time.UnixMilli(1489474082831), @@ -1321,10 +1321,10 @@ func TestWSOrderbook(t *testing.T) { err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "market.btcusdt.depth.step0", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.OrderbookChannel}) require.NoError(t, err, "AddSubscriptions must not error") testexch.FixtureToDataHandler(t, "testdata/wsOrderbook.json", e.wsHandleData) - close(e.Websocket.DataHandler) - require.Len(t, e.Websocket.DataHandler, 1, "Must see correct number of records") - dAny := <-e.Websocket.DataHandler - d, ok := dAny.(*orderbook.Depth) + e.Websocket.DataHandler.Close() + require.Len(t, e.Websocket.DataHandler.C, 1, "Must see correct number of records") + dAny := <-e.Websocket.DataHandler.C + d, ok := dAny.Data.(*orderbook.Depth) require.True(t, ok, "Must get the correct type from DataHandler") require.NotNil(t, d) l, err := d.GetAskLength() @@ -1350,7 +1350,7 @@ func TestWSHandleAllTradesMsg(t *testing.T) { require.NoError(t, err, "AddSubscriptions must not error") e.SetSaveTradeDataStatus(true) testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) - close(e.Websocket.DataHandler) + e.Websocket.DataHandler.Close() exp := []trade.Data{ { Exchange: e.Name, @@ -1373,11 +1373,11 @@ func TestWSHandleAllTradesMsg(t *testing.T) { AssetType: asset.Spot, }, } - require.Len(t, e.Websocket.DataHandler, 2, "Must see correct number of trades") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + require.Len(t, e.Websocket.DataHandler.C, 2, "Must see correct number of trades") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case trade.Data: - i := 1 - len(e.Websocket.DataHandler) + i := 1 - len(e.Websocket.DataHandler.C) require.Equalf(t, exp[i], v, "Trade [%d] must be correct", i) case error: t.Error(v) @@ -1385,7 +1385,7 @@ func TestWSHandleAllTradesMsg(t *testing.T) { t.Errorf("Unexpected type in DataHandler: %T(%s)", v, v) } } - require.Empty(t, e.Websocket.DataHandler, "Must not see any errors going to datahandler") + require.Empty(t, e.Websocket.DataHandler.C, "Must not see any errors going to datahandler") } func TestWSTicker(t *testing.T) { @@ -1395,10 +1395,10 @@ func TestWSTicker(t *testing.T) { err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "market.btcusdt.detail", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.TickerChannel}) require.NoError(t, err, "AddSubscriptions must not error") testexch.FixtureToDataHandler(t, "testdata/wsTicker.json", e.wsHandleData) - close(e.Websocket.DataHandler) - require.Len(t, e.Websocket.DataHandler, 1, "Must see correct number of records") - tickAny := <-e.Websocket.DataHandler - tick, ok := tickAny.(*ticker.Price) + e.Websocket.DataHandler.Close() + require.Len(t, e.Websocket.DataHandler.C, 1, "Must see correct number of records") + tickAny := <-e.Websocket.DataHandler.C + tick, ok := tickAny.Data.(*ticker.Price) require.True(t, ok, "Must get the correct type from DataHandler") require.NotNil(t, tick) exp := &ticker.Price{ @@ -1425,16 +1425,16 @@ func TestWSAccountUpdate(t *testing.T) { require.NoError(t, err, "AddSubscriptions must not error") e.SetSaveTradeDataStatus(true) testexch.FixtureToDataHandler(t, "testdata/wsMyAccount.json", e.wsHandleData) - close(e.Websocket.DataHandler) - require.Len(t, e.Websocket.DataHandler, 3, "Must see correct number of records") + e.Websocket.DataHandler.Close() + require.Len(t, e.Websocket.DataHandler.C, 3, "Must see correct number of records") exp := []WsAccountUpdate{ {Currency: "btc", AccountID: 123456, Balance: 23.111, ChangeType: "transfer", AccountType: "trade", ChangeTime: types.Time(time.UnixMilli(1568601800000)), SeqNum: 1}, {Currency: "btc", AccountID: 33385, Available: 2028.69, ChangeType: "order.match", AccountType: "trade", ChangeTime: types.Time(time.UnixMilli(1574393385167)), SeqNum: 2}, {Currency: "usdt", AccountID: 14884859, Available: 20.29388158, Balance: 20.29388158, AccountType: "trade", SeqNum: 3}, } for _, ex := range exp { - uAny := <-e.Websocket.DataHandler - u, ok := uAny.(WsAccountUpdate) + uAny := <-e.Websocket.DataHandler.C + u, ok := uAny.Data.(WsAccountUpdate) require.True(t, ok, "Must get the correct type from DataHandler") require.NotNil(t, u) assert.Equal(t, ex, u) @@ -1449,10 +1449,10 @@ func TestWSOrderUpdate(t *testing.T) { require.NoError(t, err, "AddSubscriptions must not error") e.SetSaveTradeDataStatus(true) errs := testexch.FixtureToDataHandlerWithErrors(t, "testdata/wsMyOrders.json", e.wsHandleData) - close(e.Websocket.DataHandler) + e.Websocket.DataHandler.Close() require.Equal(t, 1, len(errs), "Must receive the correct number of errors back") require.ErrorContains(t, errs[0].Err, "error with order \"test1\": invalid.client.order.id (NT) (2002)") - require.Len(t, e.Websocket.DataHandler, 4, "Must see correct number of records") + require.Len(t, e.Websocket.DataHandler.C, 4, "Must see correct number of records") exp := []*order.Detail{ { Exchange: e.Name, @@ -1499,9 +1499,9 @@ func TestWSOrderUpdate(t *testing.T) { }, } for _, ex := range exp { - m := <-e.Websocket.DataHandler - require.IsType(t, &order.Detail{}, m, "Must get the correct type from DataHandler") - d, _ := m.(*order.Detail) + m := <-e.Websocket.DataHandler.C + require.IsType(t, &order.Detail{}, m.Data, "Must get the correct type from DataHandler") + d, _ := m.Data.(*order.Detail) require.NotNil(t, d) assert.Equal(t, ex, d, "Order Detail should match") } @@ -1515,9 +1515,9 @@ func TestWSMyTrades(t *testing.T) { require.NoError(t, err, "AddSubscriptions must not error") e.SetSaveTradeDataStatus(true) testexch.FixtureToDataHandler(t, "testdata/wsMyTrades.json", e.wsHandleData) - close(e.Websocket.DataHandler) - require.Len(t, e.Websocket.DataHandler, 1, "Must see correct number of records") - m := <-e.Websocket.DataHandler + e.Websocket.DataHandler.Close() + require.Len(t, e.Websocket.DataHandler.C, 1, "Must see correct number of records") + m := <-e.Websocket.DataHandler.C exp := &order.Detail{ Exchange: e.Name, Pair: btcusdtPair, @@ -1543,8 +1543,8 @@ func TestWSMyTrades(t *testing.T) { }, }, } - require.IsType(t, &order.Detail{}, m, "Must get the correct type from DataHandler") - d, _ := m.(*order.Detail) + require.IsType(t, &order.Detail{}, m.Data, "Must get the correct type from DataHandler") + d, _ := m.Data.(*order.Detail) require.NotNil(t, d) assert.Equal(t, exp, d, "Order Detail should match") } diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index 591f2af50e1..04315f5e5ba 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -110,7 +110,9 @@ func (e *Exchange) wsReadMsgs(ctx context.Context, s websocket.Connection) { } if err := e.wsHandleData(ctx, msg.Raw); err != nil { - e.Websocket.DataHandler <- err + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } @@ -144,14 +146,12 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { if s == nil { return fmt.Errorf("%w: %q", subscription.ErrNotFound, ch) } - return e.wsHandleChannelMsgs(s, respRaw) + return e.wsHandleChannelMsgs(ctx, s, respRaw) } - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), - } - - return nil + }) } // wsHandleV1ping handles v1 style pings, currently only used with public connections @@ -181,27 +181,27 @@ func (e *Exchange) wsHandleV2subResp(action string, respRaw []byte) error { return nil } -func (e *Exchange) wsHandleChannelMsgs(s *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) wsHandleChannelMsgs(ctx context.Context, s *subscription.Subscription, respRaw []byte) error { switch s.Channel { case subscription.TickerChannel: - return e.wsHandleTickerMsg(s, respRaw) + return e.wsHandleTickerMsg(ctx, s, respRaw) case subscription.OrderbookChannel: return e.wsHandleOrderbookMsg(s, respRaw) case subscription.CandlesChannel: - return e.wsHandleCandleMsg(s, respRaw) + return e.wsHandleCandleMsg(ctx, s, respRaw) case subscription.AllTradesChannel: - return e.wsHandleAllTradesMsg(s, respRaw) + return e.wsHandleAllTradesMsg(ctx, s, respRaw) case subscription.MyAccountChannel: - return e.wsHandleMyAccountMsg(respRaw) + return e.wsHandleMyAccountMsg(ctx, respRaw) case subscription.MyOrdersChannel: - return e.wsHandleMyOrdersMsg(s, respRaw) + return e.wsHandleMyOrdersMsg(ctx, s, respRaw) case subscription.MyTradesChannel: - return e.wsHandleMyTradesMsg(s, respRaw) + return e.wsHandleMyTradesMsg(ctx, s, respRaw) } return fmt.Errorf("%w: %s", common.ErrNotYetImplemented, s.Channel) } -func (e *Exchange) wsHandleCandleMsg(s *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) wsHandleCandleMsg(ctx context.Context, s *subscription.Subscription, respRaw []byte) error { if len(s.Pairs) != 1 { return subscription.ErrNotSinglePair } @@ -209,7 +209,7 @@ func (e *Exchange) wsHandleCandleMsg(s *subscription.Subscription, respRaw []byt if err := json.Unmarshal(respRaw, &c); err != nil { return err } - e.Websocket.DataHandler <- websocket.KlineData{ + return e.Websocket.DataHandler.Send(ctx, websocket.KlineData{ Timestamp: c.Timestamp.Time(), Exchange: e.Name, AssetType: s.Asset, @@ -220,11 +220,10 @@ func (e *Exchange) wsHandleCandleMsg(s *subscription.Subscription, respRaw []byt LowPrice: c.Tick.Low, Volume: c.Tick.Volume, Interval: s.Interval.String(), - } - return nil + }) } -func (e *Exchange) wsHandleAllTradesMsg(s *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) wsHandleAllTradesMsg(ctx context.Context, s *subscription.Subscription, respRaw []byte) error { saveTradeData := e.IsSaveTradeDataEnabled() tradeFeed := e.IsTradeFeedEnabled() if !saveTradeData && !tradeFeed { @@ -256,7 +255,9 @@ func (e *Exchange) wsHandleAllTradesMsg(s *subscription.Subscription, respRaw [] } if tradeFeed { for i := range trades { - e.Websocket.DataHandler <- trades[i] + if err := e.Websocket.DataHandler.Send(ctx, trades[i]); err != nil { + return err + } } } if saveTradeData { @@ -265,7 +266,7 @@ func (e *Exchange) wsHandleAllTradesMsg(s *subscription.Subscription, respRaw [] return nil } -func (e *Exchange) wsHandleTickerMsg(s *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) wsHandleTickerMsg(ctx context.Context, s *subscription.Subscription, respRaw []byte) error { if len(s.Pairs) != 1 { return subscription.ErrNotSinglePair } @@ -273,7 +274,7 @@ func (e *Exchange) wsHandleTickerMsg(s *subscription.Subscription, respRaw []byt if err := json.Unmarshal(respRaw, &wsTicker); err != nil { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Open: wsTicker.Tick.Open, Close: wsTicker.Tick.Close, @@ -284,8 +285,7 @@ func (e *Exchange) wsHandleTickerMsg(s *subscription.Subscription, respRaw []byt LastUpdated: wsTicker.Timestamp.Time(), AssetType: s.Asset, Pair: s.Pairs[0], - } - return nil + }) } func (e *Exchange) wsHandleOrderbookMsg(s *subscription.Subscription, respRaw []byte) error { @@ -340,7 +340,7 @@ func (e *Exchange) wsHandleOrderbookMsg(s *subscription.Subscription, respRaw [] return e.Websocket.Orderbook.LoadSnapshot(&newOrderBook) } -func (e *Exchange) wsHandleMyOrdersMsg(s *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) wsHandleMyOrdersMsg(ctx context.Context, s *subscription.Subscription, respRaw []byte) error { var msg wsOrderUpdateMsg if err := json.Unmarshal(respRaw, &msg); err != nil { return err @@ -399,14 +399,16 @@ func (e *Exchange) wsHandleMyOrdersMsg(s *subscription.Subscription, respRaw []b } } } - e.Websocket.DataHandler <- d + if err := e.Websocket.DataHandler.Send(ctx, d); err != nil { + return err + } if o.ErrCode != 0 { return fmt.Errorf("error with order %q: %s (%v)", o.ClientOrderID, o.ErrMessage, o.ErrCode) } return nil } -func (e *Exchange) wsHandleMyTradesMsg(s *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) wsHandleMyTradesMsg(ctx context.Context, s *subscription.Subscription, respRaw []byte) error { var msg wsTradeUpdateMsg if err := json.Unmarshal(respRaw, &msg); err != nil { return err @@ -468,17 +470,15 @@ func (e *Exchange) wsHandleMyTradesMsg(s *subscription.Subscription, respRaw []b Timestamp: t.TradeTime.Time(), }, } - e.Websocket.DataHandler <- d - return nil + return e.Websocket.DataHandler.Send(ctx, d) } -func (e *Exchange) wsHandleMyAccountMsg(respRaw []byte) error { +func (e *Exchange) wsHandleMyAccountMsg(ctx context.Context, respRaw []byte) error { u := &wsAccountUpdateMsg{} if err := json.Unmarshal(respRaw, u); err != nil { return err } - e.Websocket.DataHandler <- u.Data - return nil + return e.Websocket.DataHandler.Send(ctx, u.Data) } // generateSubscriptions returns a list of subscriptions from the configured subscriptions feature diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 9b4f1703021..1dfe4e736de 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -1233,39 +1233,15 @@ func (e *Exchange) GetOrderInfo(ctx context.Context, orderID string, pair curren typeDetails := strings.Split(respData.Type, "-") orderSide, err := order.StringToOrderSide(typeDetails[0]) if err != nil { - if e.Websocket.IsConnected() { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } - } else { - return nil, err - } + return nil, err } orderType, err := order.StringToOrderType(typeDetails[1]) if err != nil { - if e.Websocket.IsConnected() { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } - } else { - return nil, err - } + return nil, err } orderStatus, err := order.StringToOrderStatus(respData.State) if err != nil { - if e.Websocket.IsConnected() { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: orderID, - Err: err, - } - } else { - return nil, err - } + return nil, err } var p currency.Pair var a asset.Item diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index bf65a1022ac..c7d2118ea1a 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1003,7 +1003,7 @@ func TestWsResubscribe(t *testing.T) { err = subs[0].SetState(subscription.UnsubscribingState) require.NoError(t, err) - err = e.Websocket.ResubscribeToChannel(e.Websocket.Conn, subs[0]) + err = e.Websocket.ResubscribeToChannel(t.Context(), e.Websocket.Conn, subs[0]) require.NoError(t, err, "Resubscribe must not error") require.Equal(t, subscription.SubscribedState, subs[0].State(), "subscription must be subscribed again") } @@ -1202,25 +1202,25 @@ func TestWSProcessTrades(t *testing.T) { err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{spotTestPair}, Channel: subscription.AllTradesChannel, Key: 18788}) require.NoError(t, err, "AddSubscriptions must not error") testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) - close(e.Websocket.DataHandler) + e.Websocket.DataHandler.Close() invalid := []any{"trades", []any{[]any{"95873.80000", "0.00051182", "1708731380.3791859"}}} rawBytes, err := json.Marshal(invalid) require.NoError(t, err, "Marshal must not error marshalling invalid trade data") pair := currency.NewPair(currency.XBT, currency.USD) - err = e.wsProcessTrades(json.RawMessage(rawBytes), pair) + err = e.wsProcessTrades(t.Context(), json.RawMessage(rawBytes), pair) require.ErrorContains(t, err, "error unmarshalling trade data") expJSON := []string{ `{"AssetType":"spot","CurrencyPair":"XBT/USD","Side":"BUY","Price":95873.80000,"Amount":0.00051182,"Timestamp":"2025-02-23T23:29:40.379186Z"}`, `{"AssetType":"spot","CurrencyPair":"XBT/USD","Side":"SELL","Price":95940.90000,"Amount":0.00011069,"Timestamp":"2025-02-24T02:01:12.853682Z"}`, } - require.Len(t, e.Websocket.DataHandler, len(expJSON), "Must see correct number of trades") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + require.Len(t, e.Websocket.DataHandler.C, len(expJSON), "Must see correct number of trades") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case trade.Data: - i := 1 - len(e.Websocket.DataHandler) + i := 1 - len(e.Websocket.DataHandler.C) exp := trade.Data{Exchange: e.Name, CurrencyPair: spotTestPair} require.NoErrorf(t, json.Unmarshal([]byte(expJSON[i]), &exp), "Must not error unmarshalling json %d: %s", i, expJSON[i]) require.Equalf(t, exp, v, "Trade [%d] must be correct", i) @@ -1238,12 +1238,12 @@ func TestWsOpenOrders(t *testing.T) { require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") testexch.UpdatePairsOnce(t, e) testexch.FixtureToDataHandler(t, "testdata/wsOpenTrades.json", e.wsHandleData) - close(e.Websocket.DataHandler) - assert.Len(t, e.Websocket.DataHandler, 7, "Should see 7 orders") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + e.Websocket.DataHandler.Close() + assert.Len(t, e.Websocket.DataHandler.C, 7, "Should see 7 orders") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case *order.Detail: - switch len(e.Websocket.DataHandler) { + switch len(e.Websocket.DataHandler.C) { case 6: assert.Equal(t, "OGTT3Y-C6I3P-XRI6HR", v.OrderID, "OrderID") assert.Equal(t, order.Limit, v.Type, "order type") diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 0841e69d25c..2fa97dda80f 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -105,10 +105,8 @@ func (e *Exchange) WsConnect() error { return err } - comms := make(chan websocket.Response) - e.Websocket.Wg.Add(2) - go e.wsReadData(ctx, comms) - go e.wsFunnelConnectionData(e.Websocket.Conn, comms) + e.Websocket.Wg.Add(1) + go e.wsReadData(ctx, e.Websocket.Conn) if e.IsWebsocketAuthenticationSupported() { if authToken, err := e.GetWebsocketToken(ctx); err != nil { @@ -122,7 +120,7 @@ func (e *Exchange) WsConnect() error { e.setWebsocketAuthToken(authToken) e.Websocket.SetCanUseAuthenticatedEndpoints(true) e.Websocket.Wg.Add(1) - go e.wsFunnelConnectionData(e.Websocket.AuthConn, comms) + go e.wsReadData(ctx, e.Websocket.AuthConn) e.startWsPingHandler(e.Websocket.AuthConn) } } @@ -134,47 +132,22 @@ func (e *Exchange) WsConnect() error { } // wsFunnelConnectionData funnels both auth and public ws data into one manageable place -func (e *Exchange) wsFunnelConnectionData(ws websocket.Connection, comms chan websocket.Response) { +func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) { defer e.Websocket.Wg.Done() for { resp := ws.ReadMessage() if resp.Raw == nil { return } - comms <- resp - } -} - -// wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData(ctx context.Context, comms chan websocket.Response) { - defer e.Websocket.Wg.Done() - - for { - select { - case <-e.Websocket.ShutdownC: - select { - case resp := <-comms: - err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - select { - case e.Websocket.DataHandler <- err: - default: - log.Errorf(log.WebsocketMgr, "%s websocket handle data error: %v", e.Name, err) - } - } - default: - } - return - case resp := <-comms: - err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - e.Websocket.DataHandler <- err + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, ws.GetURL(), errSend, err) } } } } -func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { if strings.HasPrefix(string(respRaw), "[") { var msg []json.RawMessage if err := json.Unmarshal(respRaw, &msg); err != nil { @@ -200,7 +173,7 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { pair = p } - return e.wsReadDataResponse(chanName, pair, msg) + return e.wsReadDataResponse(ctx, chanName, pair, msg) } event, err := jsonparser.GetString(respRaw, "event") @@ -230,12 +203,10 @@ func (e *Exchange) wsHandleData(_ context.Context, respRaw []byte) error { case krakenWsSystemStatus: return e.wsProcessSystemStatus(respRaw) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: fmt.Sprintf("%s: %s", websocket.UnhandledMessage, respRaw), - } + }) } - - return nil } // startWsPingHandler sets up a websocket ping handler to maintain a connection @@ -248,26 +219,26 @@ func (e *Exchange) startWsPingHandler(conn websocket.Connection) { } // wsReadDataResponse classifies the WS response and sends to appropriate handler -func (e *Exchange) wsReadDataResponse(c string, pair currency.Pair, response []json.RawMessage) error { +func (e *Exchange) wsReadDataResponse(ctx context.Context, c string, pair currency.Pair, response []json.RawMessage) error { switch c { case krakenWsTicker: - return e.wsProcessTickers(response[1], pair) + return e.wsProcessTickers(ctx, response[1], pair) case krakenWsSpread: return e.wsProcessSpread(response[1], pair) case krakenWsTrade: - return e.wsProcessTrades(response[1], pair) + return e.wsProcessTrades(ctx, response[1], pair) case krakenWsOwnTrades: - return e.wsProcessOwnTrades(response[0]) + return e.wsProcessOwnTrades(ctx, response[0]) case krakenWsOpenOrders: - return e.wsProcessOpenOrders(response[0]) + return e.wsProcessOpenOrders(ctx, response[0]) } channelType := strings.TrimRight(c, "-0123456789") switch channelType { case krakenWsOHLC: - return e.wsProcessCandle(c, response[1], pair) + return e.wsProcessCandle(ctx, c, response[1], pair) case krakenWsOrderbook: - return e.wsProcessOrderBook(c, response, pair) + return e.wsProcessOrderBook(ctx, c, response, pair) default: return fmt.Errorf("received unidentified data for subscription %s: %+v", c, response) } @@ -279,7 +250,7 @@ func (e *Exchange) wsProcessSystemStatus(respRaw []byte) error { return fmt.Errorf("%s parsing system status: %s", err, respRaw) } if systemStatus.Status != "online" { - e.Websocket.DataHandler <- fmt.Errorf("system status not online: %v", systemStatus.Status) + return fmt.Errorf("system status not online: %v", systemStatus.Status) } if systemStatus.Version > krakenWSSupportedVersion { log.Warnf(log.ExchangeSys, "%v New version of Websocket API released. Was %v Now %v", e.Name, krakenWSSupportedVersion, systemStatus.Version) @@ -287,7 +258,7 @@ func (e *Exchange) wsProcessSystemStatus(respRaw []byte) error { return nil } -func (e *Exchange) wsProcessOwnTrades(ownOrdersRaw json.RawMessage) error { +func (e *Exchange) wsProcessOwnTrades(ctx context.Context, ownOrdersRaw json.RawMessage) error { var result []map[string]*WsOwnTrade if err := json.Unmarshal(ownOrdersRaw, &result); err != nil { return err @@ -300,21 +271,13 @@ func (e *Exchange) wsProcessOwnTrades(ownOrdersRaw json.RawMessage) error { for key, val := range result[0] { oSide, err := order.StringToOrderSide(val.Type) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: key, - Err: err, - } + return err } oType, err := order.StringToOrderType(val.OrderType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: key, - Err: err, - } + return err } - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, OrderID: val.OrderTransactionID, Trades: []order.TradeHistory{ @@ -329,13 +292,15 @@ func (e *Exchange) wsProcessOwnTrades(ownOrdersRaw json.RawMessage) error { Timestamp: val.Time.Time(), }, }, + }); err != nil { + return err } } return nil } // wsProcessOpenOrders processes open orders from the websocket response -func (e *Exchange) wsProcessOpenOrders(ownOrdersResp json.RawMessage) error { +func (e *Exchange) wsProcessOpenOrders(ctx context.Context, ownOrdersResp json.RawMessage) error { var result []map[string]*WsOpenOrder if err := json.Unmarshal(ownOrdersResp, &result); err != nil { return err @@ -356,57 +321,28 @@ func (e *Exchange) wsProcessOpenOrders(ownOrdersResp json.RawMessage) error { } if val.Status != "" { - if s, err := order.StringToOrderStatus(val.Status); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: key, - Err: err, - } - } else { - d.Status = s + var err error + if d.Status, err = order.StringToOrderStatus(val.Status); err != nil { + return err } } if val.Description.Pair != "" { - if strings.Contains(val.Description.Order, "sell") { - d.Side = order.Sell - } else { - if oSide, err := order.StringToOrderSide(val.Description.Type); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: key, - Err: err, - } - } else { - d.Side = oSide + var err error + d.Side = order.Sell + if !strings.Contains(val.Description.Order, "sell") { + if d.Side, err = order.StringToOrderSide(val.Description.Type); err != nil { + return err } } - - if oType, err := order.StringToOrderType(val.Description.OrderType); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: key, - Err: err, - } - } else { - d.Type = oType + if d.Type, err = order.StringToOrderType(val.Description.OrderType); err != nil { + return err } - - if p, err := currency.NewPairFromString(val.Description.Pair); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: key, - Err: err, - } - } else { - d.Pair = p - if d.AssetType, err = e.GetPairAssetType(p); err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: key, - Err: err, - } - } + if d.Pair, err = currency.NewPairFromString(val.Description.Pair); err != nil { + return err + } + if d.AssetType, err = e.GetPairAssetType(d.Pair); err != nil { + return err } } @@ -419,20 +355,22 @@ func (e *Exchange) wsProcessOpenOrders(ownOrdersResp json.RawMessage) error { // Note: Volume and ExecutedVolume are only populated when status is open d.RemainingAmount = val.Volume - val.ExecutedVolume } - e.Websocket.DataHandler <- d + if err := e.Websocket.DataHandler.Send(ctx, d); err != nil { + return err + } } } return nil } // wsProcessTickers converts ticker data and sends it to the datahandler -func (e *Exchange) wsProcessTickers(dataRaw json.RawMessage, pair currency.Pair) error { +func (e *Exchange) wsProcessTickers(ctx context.Context, dataRaw json.RawMessage, pair currency.Pair) error { var t wsTicker if err := json.Unmarshal(dataRaw, &t); err != nil { return fmt.Errorf("error unmarshalling ticker data: %w", err) } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Ask: t.Ask[0].Float64(), Bid: t.Bid[0].Float64(), @@ -443,8 +381,7 @@ func (e *Exchange) wsProcessTickers(dataRaw json.RawMessage, pair currency.Pair) Open: t.Open[0].Float64(), AssetType: asset.Spot, Pair: pair, - } - return nil + }) } // wsProcessSpread converts spread/orderbook data and sends it to the datahandler @@ -467,7 +404,7 @@ func (e *Exchange) wsProcessSpread(rawData json.RawMessage, pair currency.Pair) } // wsProcessTrades converts trade data and sends it to the datahandler -func (e *Exchange) wsProcessTrades(respRaw json.RawMessage, pair currency.Pair) error { +func (e *Exchange) wsProcessTrades(ctx context.Context, respRaw json.RawMessage, pair currency.Pair) error { saveTradeData := e.IsSaveTradeDataEnabled() tradeFeed := e.IsTradeFeedEnabled() if !saveTradeData && !tradeFeed { @@ -497,7 +434,9 @@ func (e *Exchange) wsProcessTrades(respRaw json.RawMessage, pair currency.Pair) } if tradeFeed { for i := range trades { - e.Websocket.DataHandler <- trades[i] + if err := e.Websocket.DataHandler.Send(ctx, trades[i]); err != nil { + return err + } } } if saveTradeData { @@ -515,7 +454,7 @@ func hasKey(raw json.RawMessage, key string) bool { } // wsProcessOrderBook handles both partial and full orderbook updates -func (e *Exchange) wsProcessOrderBook(c string, response []json.RawMessage, pair currency.Pair) error { +func (e *Exchange) wsProcessOrderBook(ctx context.Context, c string, response []json.RawMessage, pair currency.Pair) error { key := &subscription.Subscription{ Channel: c, Asset: asset.Spot, @@ -551,7 +490,7 @@ func (e *Exchange) wsProcessOrderBook(c string, response []json.RawMessage, pair if errors.Is(err, errInvalidChecksum) { log.Debugf(log.Global, "%s Resubscribing to invalid %s orderbook", e.Name, pair) go func() { - if e2 := e.Websocket.ResubscribeToChannel(e.Websocket.Conn, s); e2 != nil && !errors.Is(e2, subscription.ErrInStateAlready) { + if e2 := e.Websocket.ResubscribeToChannel(ctx, e.Websocket.Conn, s); e2 != nil && !errors.Is(e2, subscription.ErrInStateAlready) { log.Errorf(log.ExchangeSys, "%s resubscription failure for %v: %v", e.Name, pair, e2) } }() @@ -696,7 +635,7 @@ func trim(s string) string { } // wsProcessCandle converts candle data and sends it to the data handler -func (e *Exchange) wsProcessCandle(c string, resp json.RawMessage, pair currency.Pair) error { +func (e *Exchange) wsProcessCandle(ctx context.Context, c string, resp json.RawMessage, pair currency.Pair) error { var data wsCandle if err := json.Unmarshal(resp, &data); err != nil { return fmt.Errorf("error unmarshalling candle data: %w", err) @@ -709,7 +648,7 @@ func (e *Exchange) wsProcessCandle(c string, resp json.RawMessage, pair currency } interval := parts[1] - e.Websocket.DataHandler <- websocket.KlineData{ + return e.Websocket.DataHandler.Send(ctx, websocket.KlineData{ AssetType: asset.Spot, Pair: pair, Timestamp: time.Now(), @@ -722,8 +661,7 @@ func (e *Exchange) wsProcessCandle(c string, resp json.RawMessage, pair currency ClosePrice: data.Close.Float64(), Volume: data.Volume.Float64(), Interval: interval, - } - return nil + }) } // GetSubscriptionTemplate returns a subscription channel template @@ -1055,10 +993,12 @@ func (e *Exchange) wsAddOrder(ctx context.Context, req *WsAddOrderRequest) (stri if resp.Status == "error" { return "", errors.New("AddOrder error: " + resp.ErrorMessage) } - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, OrderID: resp.TransactionID, Status: order.New, + }); err != nil { + return "", err } return resp.TransactionID, nil } diff --git a/exchanges/kucoin/kucoin_test.go b/exchanges/kucoin/kucoin_test.go index 5a228dc2171..e28694abccf 100644 --- a/exchanges/kucoin/kucoin_test.go +++ b/exchanges/kucoin/kucoin_test.go @@ -2352,8 +2352,8 @@ func TestPushData(t *testing.T) { } return e.wsHandleData(ctx, r) }) - close(e.Websocket.DataHandler) - assert.Len(t, e.Websocket.DataHandler, 29, "Should see correct number of messages") + e.Websocket.DataHandler.Close() + assert.Len(t, e.Websocket.DataHandler.C, 29, "Should see correct number of messages") require.Len(t, fErrs, 1, "Must get exactly one error message") assert.ErrorContains(t, fErrs[0].Err, "cannot save holdings: nil pointer: *accounts.Accounts") } @@ -3004,13 +3004,13 @@ func TestProcessMarketSnapshot(t *testing.T) { t.Parallel() ku := testInstance(t) testexch.FixtureToDataHandler(t, "testdata/wsMarketSnapshot.json", ku.wsHandleData) - close(ku.Websocket.DataHandler) - assert.Len(t, ku.Websocket.DataHandler, 4, "Should see 4 tickers") + ku.Websocket.DataHandler.Close() + assert.Len(t, ku.Websocket.DataHandler.C, 4, "Should see 4 tickers") seenAssetTypes := map[asset.Item]int{} - for resp := range ku.Websocket.DataHandler { - switch v := resp.(type) { + for resp := range ku.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case *ticker.Price: - switch len(ku.Websocket.DataHandler) { + switch len(ku.Websocket.DataHandler.C) { case 3: assert.Equal(t, asset.Margin, v.AssetType, "AssetType") assert.Equal(t, time.UnixMilli(1700555342007), v.LastUpdated, "datetime") @@ -3118,6 +3118,11 @@ func TestSubscribeTickerAll(t *testing.T) { t.Parallel() ku := testInstance(t) + go func() { // drain websocket messages when subscribed to all tickers + for { + <-ku.Websocket.DataHandler.C + } + }() ku.Features.Subscriptions = subscription.List{} testexch.SetupWs(t, ku) @@ -4421,7 +4426,7 @@ func TestGetHistoricalFundingRates(t *testing.T) { func TestProcessFuturesKline(t *testing.T) { t.Parallel() data := fmt.Sprintf(`{"symbol":%q,"candles":["1714964400","63815.1","63890.8","63928.5","63797.8","17553.0","17553"],"time":1714964823722}`, futuresTradablePair.String()) - err := e.processFuturesKline([]byte(data), "1hour") + err := e.processFuturesKline(t.Context(), []byte(data), "1hour") assert.NoError(t, err) } diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index dbcf63c0cea..d08361d5f58 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -130,8 +130,8 @@ func (e *Exchange) WsConnect() error { if e.Websocket.CanUseAuthenticatedEndpoints() { instances, err = e.GetAuthenticatedInstanceServers(ctx) if err != nil { - e.Websocket.DataHandler <- err e.Websocket.SetCanUseAuthenticatedEndpoints(false) + return err } } if instances == nil { @@ -206,7 +206,9 @@ func (e *Exchange) wsReadData(ctx context.Context) { } err := e.wsHandleData(ctx, resp.Raw) if err != nil { - e.Websocket.DataHandler <- err + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } @@ -232,9 +234,9 @@ func (e *Exchange) wsHandleData(ctx context.Context, respData []byte) error { } else { instruments = topicInfo[1] } - return e.processTicker(resp.Data, instruments, topicInfo[0]) + return e.processTicker(ctx, resp.Data, instruments, topicInfo[0]) case marketSnapshotChannel: - return e.processMarketSnapshot(resp.Data, topicInfo[0]) + return e.processMarketSnapshot(ctx, resp.Data, topicInfo[0]) case marketOrderbookChannel: return e.processOrderbookWithDepth(respData, topicInfo[1], topicInfo[0]) case marketOrderbookDepth5Channel, marketOrderbookDepth50Channel: @@ -244,36 +246,36 @@ func (e *Exchange) wsHandleData(ctx context.Context, respData []byte) error { if len(symbolAndInterval) != 2 { return common.ErrMalformedData } - return e.processCandlesticks(resp.Data, symbolAndInterval[0], symbolAndInterval[1], topicInfo[0]) + return e.processCandlesticks(ctx, resp.Data, symbolAndInterval[0], symbolAndInterval[1], topicInfo[0]) case marketMatchChannel: return e.processTradeData(resp.Data, topicInfo[1], topicInfo[0]) case indexPriceIndicatorChannel, markPriceIndicatorChannel: var response WsPriceIndicator - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) case privateSpotTradeOrders: - return e.processOrderChangeEvent(resp.Data, topicInfo[0]) + return e.processOrderChangeEvent(ctx, resp.Data, topicInfo[0]) case accountBalanceChannel: return e.processAccountBalanceChange(ctx, resp.Data) case marginPositionChannel: if resp.Subject == "debt.ratio" { var response WsDebtRatioChange - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) } var response WsPositionStatus - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) case marginLoanChannel: if resp.Subject == "order.done" { var response WsMarginTradeOrderDoneEvent - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) } - return e.processMarginLendingTradeOrderEvent(resp.Data) + return e.processMarginLendingTradeOrderEvent(ctx, resp.Data) case spotMarketAdvancedChannel: - return e.processStopOrderEvent(resp.Data) + return e.processStopOrderEvent(ctx, resp.Data) case futuresTickerChannel: - return e.processFuturesTickerV2(resp.Data) + return e.processFuturesTickerV2(ctx, resp.Data) case futuresExecutionDataChannel: var response WsFuturesExecutionData - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) case futuresOrderbookChannel: if err := e.ensureFuturesOrderbookSnapshotLoaded(ctx, topicInfo[1]); err != nil { return err @@ -288,64 +290,62 @@ func (e *Exchange) wsHandleData(ctx context.Context, respData []byte) error { case futuresContractMarketDataChannel: switch resp.Subject { case "mark.index.price": - return e.processFuturesMarkPriceAndIndexPrice(resp.Data, topicInfo[1]) + return e.processFuturesMarkPriceAndIndexPrice(ctx, resp.Data, topicInfo[1]) case "funding.rate": - return e.processFuturesFundingData(resp.Data, topicInfo[1]) + return e.processFuturesFundingData(ctx, resp.Data, topicInfo[1]) } case futuresSystemAnnouncementChannel: - return e.processFuturesSystemAnnouncement(resp.Data, resp.Subject) + return e.processFuturesSystemAnnouncement(ctx, resp.Data, resp.Subject) case futuresTransactionStatisticsTimerEventChannel: return e.processFuturesTransactionStatistics(resp.Data, topicInfo[1]) case futuresTradeOrderChannel: - return e.processFuturesPrivateTradeOrders(resp.Data) + return e.processFuturesPrivateTradeOrders(ctx, resp.Data) case futuresStopOrdersLifecycleEventChannel: - return e.processFuturesStopOrderLifecycleEvent(resp.Data) + return e.processFuturesStopOrderLifecycleEvent(ctx, resp.Data) case futuresAccountBalanceEventChannel: switch resp.Subject { case "orderMargin.change": var response WsFuturesOrderMarginEvent - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) case "availableBalance.change": return e.processFuturesAccountBalanceEvent(ctx, resp.Data) case "withdrawHold.change": var response WsFuturesWithdrawalAmountAndTransferOutAmountEvent - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) } case futuresPositionChangeEventChannel: switch resp.Subject { case "position.change": if resp.ChannelType == "private" { var response WsFuturesPosition - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) } var response WsFuturesMarkPricePositionChanges - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) case "position.settlement": var response WsFuturesPositionFundingSettlement - return e.processData(resp.Data, &response) + return e.processData(ctx, resp.Data, &response) } case futuresLimitCandles: instrumentInfos := strings.Split(topicInfo[1], "_") if len(instrumentInfos) != 2 { return errors.New("invalid instrument information") } - return e.processFuturesKline(resp.Data, instrumentInfos[1]) + return e.processFuturesKline(ctx, resp.Data, instrumentInfos[1]) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respData), - } - return errors.New("push data not handled") + }) } return nil } // processData used to deserialize and forward the data to DataHandler. -func (e *Exchange) processData(respData []byte, resp any) error { +func (e *Exchange) processData(ctx context.Context, respData []byte, resp any) error { if err := json.Unmarshal(respData, &resp); err != nil { return err } - e.Websocket.DataHandler <- resp - return nil + return e.Websocket.DataHandler.Send(ctx, resp) } // processFuturesAccountBalanceEvent used to process futures account balance change incoming data. @@ -364,12 +364,11 @@ func (e *Exchange) processFuturesAccountBalanceEvent(ctx context.Context, respDa if err := e.Accounts.Save(ctx, subAccts, false); err != nil { return err } - e.Websocket.DataHandler <- subAccts - return nil + return e.Websocket.DataHandler.Send(ctx, subAccts) } // processFuturesStopOrderLifecycleEvent processes futures stop orders lifecycle events. -func (e *Exchange) processFuturesStopOrderLifecycleEvent(respData []byte) error { +func (e *Exchange) processFuturesStopOrderLifecycleEvent(ctx context.Context, respData []byte) error { resp := WsStopOrderLifecycleEvent{} err := json.Unmarshal(respData, &resp) if err != nil { @@ -392,7 +391,7 @@ func (e *Exchange) processFuturesStopOrderLifecycleEvent(respData []byte) error if err != nil { return err } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: resp.OrderPrice, TriggerPrice: resp.StopPrice, Amount: resp.Size, @@ -404,12 +403,11 @@ func (e *Exchange) processFuturesStopOrderLifecycleEvent(respData []byte) error Date: resp.CreatedAt.Time(), LastUpdated: resp.Timestamp.Time(), Pair: pair, - } - return nil + }) } // processFuturesPrivateTradeOrders processes futures private trade orders updates. -func (e *Exchange) processFuturesPrivateTradeOrders(respData []byte) error { +func (e *Exchange) processFuturesPrivateTradeOrders(ctx context.Context, respData []byte) error { resp := WsFuturesTradeOrder{} if err := json.Unmarshal(respData, &resp); err != nil { return err @@ -435,7 +433,7 @@ func (e *Exchange) processFuturesPrivateTradeOrders(respData []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Type: oType, Status: oStatus, Pair: pair, @@ -449,8 +447,7 @@ func (e *Exchange) processFuturesPrivateTradeOrders(respData []byte) error { OrderID: resp.TradeID, AssetType: asset.Futures, LastUpdated: resp.OrderTime.Time(), - } - return nil + }) } // processFuturesTransactionStatistics processes a futures transaction statistics @@ -464,36 +461,33 @@ func (e *Exchange) processFuturesTransactionStatistics(respData []byte, instrume } // processFuturesSystemAnnouncement processes a system announcement. -func (e *Exchange) processFuturesSystemAnnouncement(respData []byte, subject string) error { +func (e *Exchange) processFuturesSystemAnnouncement(ctx context.Context, respData []byte, subject string) error { resp := WsFuturesFundingBegin{} if err := json.Unmarshal(respData, &resp); err != nil { return err } resp.Subject = subject - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } // processFuturesFundingData processes a futures account funding data. -func (e *Exchange) processFuturesFundingData(respData []byte, instrument string) error { +func (e *Exchange) processFuturesFundingData(ctx context.Context, respData []byte, instrument string) error { resp := WsFundingRate{} if err := json.Unmarshal(respData, &resp); err != nil { return err } resp.Symbol = instrument - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } // processFuturesMarkPriceAndIndexPrice processes a futures account mark price and index price changes. -func (e *Exchange) processFuturesMarkPriceAndIndexPrice(respData []byte, instrument string) error { +func (e *Exchange) processFuturesMarkPriceAndIndexPrice(ctx context.Context, respData []byte, instrument string) error { resp := WsFuturesMarkPriceAndIndexPrice{} if err := json.Unmarshal(respData, &resp); err != nil { return err } resp.Symbol = instrument - e.Websocket.DataHandler <- &resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } // ensureFuturesOrderbookSnapshotLoaded makes sure an initial futures orderbook snapshot is loaded @@ -569,7 +563,7 @@ func (e *Exchange) processFuturesOrderbookLevel2(ctx context.Context, respData [ } // processFuturesTickerV2 processes a futures account ticker data. -func (e *Exchange) processFuturesTickerV2(respData []byte) error { +func (e *Exchange) processFuturesTickerV2(ctx context.Context, respData []byte) error { resp := WsFuturesTicker{} if err := json.Unmarshal(respData, &resp); err != nil { return err @@ -582,7 +576,7 @@ func (e *Exchange) processFuturesTickerV2(respData []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ AssetType: asset.Futures, Last: resp.FilledPrice.Float64(), Volume: resp.FilledSize.Float64(), @@ -593,12 +587,11 @@ func (e *Exchange) processFuturesTickerV2(respData []byte) error { Bid: resp.BestBidPrice.Float64(), AskSize: resp.BestAskSize.Float64(), BidSize: resp.BestBidSize.Float64(), - } - return nil + }) } // processFuturesKline represents a futures instrument kline data update. -func (e *Exchange) processFuturesKline(respData []byte, intervalStr string) error { +func (e *Exchange) processFuturesKline(ctx context.Context, respData []byte, intervalStr string) error { resp := WsFuturesKline{} err := json.Unmarshal(respData, &resp) if err != nil { @@ -609,7 +602,7 @@ func (e *Exchange) processFuturesKline(respData []byte, intervalStr string) erro if err != nil { return err } - e.Websocket.DataHandler <- &websocket.KlineData{ + return e.Websocket.DataHandler.Send(ctx, &websocket.KlineData{ Timestamp: resp.Time.Time(), AssetType: asset.Futures, Exchange: e.Name, @@ -621,12 +614,11 @@ func (e *Exchange) processFuturesKline(respData []byte, intervalStr string) erro LowPrice: resp.Candles[4].Float64(), Volume: resp.Candles[6].Float64(), Pair: pair, - } - return nil + }) } // processStopOrderEvent represents a stop order update event. -func (e *Exchange) processStopOrderEvent(respData []byte) error { +func (e *Exchange) processStopOrderEvent(ctx context.Context, respData []byte) error { resp := WsStopOrder{} err := json.Unmarshal(respData, &resp) if err != nil { @@ -645,7 +637,7 @@ func (e *Exchange) processStopOrderEvent(respData []byte) error { if err != nil { return err } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: resp.OrderPrice, TriggerPrice: resp.StopPrice, Amount: resp.Size, @@ -657,18 +649,16 @@ func (e *Exchange) processStopOrderEvent(respData []byte) error { Date: resp.CreatedAt.Time(), LastUpdated: resp.Timestamp.Time(), Pair: pair, - } - return nil + }) } // processMarginLendingTradeOrderEvent represents a margin lending trade order event. -func (e *Exchange) processMarginLendingTradeOrderEvent(respData []byte) error { +func (e *Exchange) processMarginLendingTradeOrderEvent(ctx context.Context, respData []byte) error { resp := WsMarginTradeOrderEntersEvent{} if err := json.Unmarshal(respData, &resp); err != nil { return err } - e.Websocket.DataHandler <- resp - return nil + return e.Websocket.DataHandler.Send(ctx, &resp) } // processAccountBalanceChange processes an account balance change @@ -687,12 +677,11 @@ func (e *Exchange) processAccountBalanceChange(ctx context.Context, respData []b if err := e.Accounts.Save(ctx, subAccts, false); err != nil { return err } - e.Websocket.DataHandler <- subAccts - return nil + return e.Websocket.DataHandler.Send(ctx, subAccts) } // processOrderChangeEvent processes order update events. -func (e *Exchange) processOrderChangeEvent(respData []byte, topic string) error { +func (e *Exchange) processOrderChangeEvent(ctx context.Context, respData []byte, topic string) error { response := WsTradeOrder{} err := json.Unmarshal(respData, &response) if err != nil { @@ -720,7 +709,7 @@ func (e *Exchange) processOrderChangeEvent(respData []byte, topic string) error return err } for x := range assets { - e.Websocket.DataHandler <- &order.Detail{ + if err := e.Websocket.DataHandler.Send(ctx, &order.Detail{ Price: response.Price, Amount: response.Size, ExecutedAmount: response.FilledSize, @@ -735,6 +724,8 @@ func (e *Exchange) processOrderChangeEvent(respData []byte, topic string) error Date: response.OrderTime.Time(), LastUpdated: response.Timestamp.Time(), Pair: pair, + }); err != nil { + return err } } return nil @@ -783,7 +774,7 @@ func (e *Exchange) processTradeData(respData []byte, instrument, topic string) e } // processTicker processes a ticker data for an instrument. -func (e *Exchange) processTicker(respData []byte, instrument, topic string) error { +func (e *Exchange) processTicker(ctx context.Context, respData []byte, instrument, topic string) error { response := WsTicker{} err := json.Unmarshal(respData, &response) if err != nil { @@ -801,7 +792,7 @@ func (e *Exchange) processTicker(respData []byte, instrument, topic string) erro if !e.AssetWebsocketSupport.IsAssetWebsocketSupported(assets[x]) { continue } - e.Websocket.DataHandler <- &ticker.Price{ + if err := e.Websocket.DataHandler.Send(ctx, &ticker.Price{ AssetType: assets[x], Last: response.Price, LastUpdated: response.Timestamp.Time(), @@ -812,13 +803,15 @@ func (e *Exchange) processTicker(respData []byte, instrument, topic string) erro AskSize: response.BestAskSize, BidSize: response.BestBidSize, Volume: response.Size, + }); err != nil { + return err } } return nil } // processCandlesticks processes a candlestick data for an instrument with a particular interval -func (e *Exchange) processCandlesticks(respData []byte, instrument, intervalString, topic string) error { +func (e *Exchange) processCandlesticks(ctx context.Context, respData []byte, instrument, intervalString, topic string) error { pair, err := currency.NewPairFromString(instrument) if err != nil { return err @@ -835,7 +828,7 @@ func (e *Exchange) processCandlesticks(respData []byte, instrument, intervalStri if !e.AssetWebsocketSupport.IsAssetWebsocketSupported(assets[x]) { continue } - e.Websocket.DataHandler <- &websocket.KlineData{ + if err := e.Websocket.DataHandler.Send(ctx, &websocket.KlineData{ Timestamp: resp.Time.Time(), Pair: pair, AssetType: assets[x], @@ -847,6 +840,8 @@ func (e *Exchange) processCandlesticks(respData []byte, instrument, intervalStri HighPrice: resp.Candles.HighPrice.Float64(), LowPrice: resp.Candles.LowPrice.Float64(), Volume: resp.Candles.TransactionVolume.Float64(), + }); err != nil { + return err } } return nil @@ -954,7 +949,7 @@ func (e *Exchange) processOrderbook(respData []byte, symbol, topic string) error } // processMarketSnapshot processes a price ticker information for a symbol. -func (e *Exchange) processMarketSnapshot(respData []byte, topic string) error { +func (e *Exchange) processMarketSnapshot(ctx context.Context, respData []byte, topic string) error { response := WsSnapshot{} err := json.Unmarshal(respData, &response) if err != nil { @@ -972,7 +967,7 @@ func (e *Exchange) processMarketSnapshot(respData []byte, topic string) error { if !e.AssetWebsocketSupport.IsAssetWebsocketSupported(assets[x]) { continue } - e.Websocket.DataHandler <- &ticker.Price{ + if err := e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, AssetType: assets[x], Last: response.Data.LastTradedPrice, @@ -984,6 +979,8 @@ func (e *Exchange) processMarketSnapshot(respData []byte, topic string) error { Open: response.Data.Open, Close: response.Data.Close, LastUpdated: response.Data.Datetime.Time(), + }); err != nil { + return err } } return nil diff --git a/exchanges/okx/okx_test.go b/exchanges/okx/okx_test.go index d89388296a4..f0d210b1ad8 100644 --- a/exchanges/okx/okx_test.go +++ b/exchanges/okx/okx_test.go @@ -3913,12 +3913,12 @@ func TestOrderPushData(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") testexch.FixtureToDataHandler(t, "testdata/wsOrders.json", e.WsHandleData) - close(e.Websocket.DataHandler) - require.Len(t, e.Websocket.DataHandler, 4, "Should see 4 orders") - for resp := range e.Websocket.DataHandler { - switch v := resp.(type) { + e.Websocket.DataHandler.Close() + require.Len(t, e.Websocket.DataHandler.C, 4, "Should see 4 orders") + for resp := range e.Websocket.DataHandler.C { + switch v := resp.Data.(type) { case *order.Detail: - switch len(e.Websocket.DataHandler) { + switch len(e.Websocket.DataHandler.C) { case 3: assert.Equal(t, "452197707845865472", v.OrderID, "OrderID") assert.Equal(t, "HamsterParty14", v.ClientOrderID, "ClientOrderID") @@ -4099,13 +4099,13 @@ func TestWSProcessTrades(t *testing.T) { } total := len(assets) * len(exp) - require.Len(t, e.Websocket.DataHandler, total, "Must see correct number of trades") + require.Len(t, e.Websocket.DataHandler.C, total, "Must see correct number of trades") trades := make(map[asset.Item][]trade.Data) - for len(e.Websocket.DataHandler) > 0 { - resp := <-e.Websocket.DataHandler - switch v := resp.(type) { + for len(e.Websocket.DataHandler.C) > 0 { + resp := <-e.Websocket.DataHandler.C + switch v := resp.Data.(type) { case trade.Data: trades[v.AssetType] = append(trades[v.AssetType], v) case error: @@ -6121,8 +6121,8 @@ func TestBusinessWSCandleSubscriptions(t *testing.T) { var got currency.Pairs assert.Eventually(t, func() bool { select { - case a := <-e.Websocket.DataHandler: - switch v := a.(type) { + case a := <-e.Websocket.DataHandler.C: + switch v := a.Data.(type) { case websocket.KlineData: got = got.Add(v.Pair) case []CandlestickMarkPrice: @@ -6158,13 +6158,13 @@ func TestWsProcessPublicSpreadTrades(t *testing.T) { func TestWsProcessPublicSpreadTicker(t *testing.T) { t.Parallel() - err := e.wsProcessPublicSpreadTicker([]byte(okxSpreadPublicTickerJSON)) + err := e.wsProcessPublicSpreadTicker(t.Context(), []byte(okxSpreadPublicTickerJSON)) assert.NoError(t, err) } func TestWsProcessSpreadOrders(t *testing.T) { t.Parallel() - err := e.wsProcessSpreadOrders([]byte(wsProcessSpreadOrdersJSON)) + err := e.wsProcessSpreadOrders(t.Context(), []byte(wsProcessSpreadOrdersJSON)) assert.NoError(t, err) } diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index 45adf526ae9..028b2380208 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -336,7 +336,9 @@ func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) { return } if err := e.WsHandleData(ctx, resp.Raw); err != nil { - e.Websocket.DataHandler <- err + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } @@ -469,7 +471,7 @@ func (e *Exchange) WsHandleData(ctx context.Context, respRaw []byte) error { channelCandle3Mutc, channelCandle1Mutc, channelCandle1Wutc, channelCandle1Dutc, channelCandle2Dutc, channelCandle3Dutc, channelCandle5Dutc, channelCandle12Hutc, channelCandle6Hutc: - return e.wsProcessCandles(respRaw) + return e.wsProcessCandles(ctx, respRaw) case channelIndexCandle1Y, channelIndexCandle6M, channelIndexCandle3M, channelIndexCandle1M, channelIndexCandle1W, channelIndexCandle1D, channelIndexCandle2D, channelIndexCandle3D, channelIndexCandle5D, channelIndexCandle12H, channelIndexCandle6H, channelIndexCandle4H, @@ -478,78 +480,78 @@ func (e *Exchange) WsHandleData(ctx context.Context, respRaw []byte) error { channelIndexCandle3Mutc, channelIndexCandle1Mutc, channelIndexCandle1Wutc, channelIndexCandle1Dutc, channelIndexCandle2Dutc, channelIndexCandle3Dutc, channelIndexCandle5Dutc, channelIndexCandle12Hutc, channelIndexCandle6Hutc: - return e.wsProcessIndexCandles(respRaw) + return e.wsProcessIndexCandles(ctx, respRaw) case channelTickers: - return e.wsProcessTickers(respRaw) + return e.wsProcessTickers(ctx, respRaw) case channelIndexTickers: var response WsIndexTicker - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelStatus: var response WsSystemStatusResponse - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelPublicStrucBlockTrades: var response WsPublicTradesResponse - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelPublicBlockTrades: return e.wsProcessBlockPublicTrades(respRaw) case channelBlockTickers: var response WsBlockTicker - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelAccountGreeks: var response WsGreeks - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelAccount: var response WsAccountChannelPushData - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelPositions, channelLiquidationWarning: var response WsPositionResponse - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelBalanceAndPosition: return e.wsProcessBalanceAndPosition(ctx, respRaw) case channelOrders: - return e.wsProcessOrders(respRaw) + return e.wsProcessOrders(ctx, respRaw) case channelAlgoOrders: var response WsAlgoOrder - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelAlgoAdvance: var response WsAdvancedAlgoOrder - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelRFQs: var response WsRFQ - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelQuotes: var response WsQuote - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelStructureBlockTrades: var response WsStructureBlocTrade - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelSpotGridOrder: var response WsSpotGridAlgoOrder - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelGridOrdersContract: var response WsContractGridAlgoOrder - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelGridPositions: var response WsContractGridAlgoOrder - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelGridSubOrders: var response WsGridSubOrderData - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelInstruments: var response WSInstrumentResponse - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelOpenInterest: var response WSOpenInterestResponse - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelTrades, channelAllTrades: - return e.wsProcessTrades(respRaw) + return e.wsProcessTrades(ctx, respRaw) case channelEstimatedPrice: var response WsDeliveryEstimatedPrice - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelMarkPrice, channelPriceLimit: var response WsMarkPrice - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelOrderBooks5: return e.wsProcessOrderbook5(respRaw) case okxSpreadOrderbookLevel1, @@ -558,7 +560,7 @@ func (e *Exchange) WsHandleData(ctx context.Context, respRaw []byte) error { case okxSpreadPublicTrades: return e.wsProcessPublicSpreadTrades(respRaw) case okxSpreadPublicTicker: - return e.wsProcessPublicSpreadTicker(respRaw) + return e.wsProcessPublicSpreadTicker(ctx, respRaw) case channelOrderBooks, channelOrderBooks50TBT, channelBBOTBT, @@ -568,10 +570,10 @@ func (e *Exchange) WsHandleData(ctx context.Context, respRaw []byte) error { return e.wsProcessOptionTrades(respRaw) case channelOptSummary: var response WsOptionSummary - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelFundingRate: var response WsFundingRate - return e.wsProcessPushData(respRaw, &response) + return e.wsProcessPushData(ctx, respRaw, &response) case channelMarkPriceCandle1Y, channelMarkPriceCandle6M, channelMarkPriceCandle3M, channelMarkPriceCandle1M, channelMarkPriceCandle1W, channelMarkPriceCandle1D, channelMarkPriceCandle2D, channelMarkPriceCandle3D, channelMarkPriceCandle5D, channelMarkPriceCandle12H, channelMarkPriceCandle6H, channelMarkPriceCandle4H, @@ -580,9 +582,9 @@ func (e *Exchange) WsHandleData(ctx context.Context, respRaw []byte) error { channelMarkPriceCandle3Mutc, channelMarkPriceCandle1Mutc, channelMarkPriceCandle1Wutc, channelMarkPriceCandle1Dutc, channelMarkPriceCandle2Dutc, channelMarkPriceCandle3Dutc, channelMarkPriceCandle5Dutc, channelMarkPriceCandle12Hutc, channelMarkPriceCandle6Hutc: - return e.wsHandleMarkPriceCandles(respRaw) + return e.wsHandleMarkPriceCandles(ctx, respRaw) case okxSpreadOrders: - return e.wsProcessSpreadOrders(respRaw) + return e.wsProcessSpreadOrders(ctx, respRaw) case okxSpreadTrades: return e.wsProcessSpreadTrades(respRaw) case okxWithdrawalInfo: @@ -590,34 +592,33 @@ func (e *Exchange) WsHandleData(ctx context.Context, respRaw []byte) error { Arguments SubscriptionInfo `json:"arg"` Data []WsDepositInfo `json:"data"` }{} - return e.wsProcessPushData(respRaw, resp) + return e.wsProcessPushData(ctx, respRaw, resp) case okxDepositInfo: resp := &struct { Arguments SubscriptionInfo `json:"arg"` Data []WsWithdrawlInfo `json:"data"` }{} - return e.wsProcessPushData(respRaw, resp) + return e.wsProcessPushData(ctx, respRaw, resp) case channelRecurringBuy: resp := &struct { Arguments SubscriptionInfo `json:"arg"` Data []RecurringBuyOrder `json:"data"` }{} - return e.wsProcessPushData(respRaw, resp) + return e.wsProcessPushData(ctx, respRaw, resp) case liquidationOrders: var resp *LiquidationOrder - return e.wsProcessPushData(respRaw, &resp) + return e.wsProcessPushData(ctx, respRaw, &resp) case adlWarning: var resp ADLWarning - return e.wsProcessPushData(respRaw, &resp) + return e.wsProcessPushData(ctx, respRaw, &resp) case economicCalendar: var resp EconomicCalendarResponse - return e.wsProcessPushData(respRaw, &resp) + return e.wsProcessPushData(ctx, respRaw, &resp) case copyTrading: var resp CopyTradingNotification - return e.wsProcessPushData(respRaw, &resp) + return e.wsProcessPushData(ctx, respRaw, &resp) default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)} - return nil + return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{Message: e.Name + websocket.UnhandledMessage + string(respRaw)}) } } @@ -661,7 +662,7 @@ func (e *Exchange) wsProcessSpreadTrades(respRaw []byte) error { // wsProcessSpreadOrders retrieve order information from the sprd-order Websocket channel. // Data will not be pushed when first subscribed. // Data will only be pushed when triggered by events such as placing/canceling order. -func (e *Exchange) wsProcessSpreadOrders(respRaw []byte) error { +func (e *Exchange) wsProcessSpreadOrders(ctx context.Context, respRaw []byte) error { if respRaw == nil { return common.ErrNilPointer } @@ -714,12 +715,11 @@ func (e *Exchange) wsProcessSpreadOrders(respRaw []byte) error { LastUpdated: resp.Data[x].UpdateTime.Time(), } } - e.Websocket.DataHandler <- orderDetails - return nil + return e.Websocket.DataHandler.Send(ctx, orderDetails) } // wsProcessIndexCandles processes index candlestick data -func (e *Exchange) wsProcessIndexCandles(respRaw []byte) error { +func (e *Exchange) wsProcessIndexCandles(ctx context.Context, respRaw []byte) error { if respRaw == nil { return common.ErrNilPointer } @@ -763,14 +763,16 @@ func (e *Exchange) wsProcessIndexCandles(respRaw []byte) error { } for i := range assets { myCandle.AssetType = assets[i] - e.Websocket.DataHandler <- myCandle + if err := e.Websocket.DataHandler.Send(ctx, myCandle); err != nil { + return err + } } } return nil } // wsProcessPublicSpreadTicker process spread order ticker push data. -func (e *Exchange) wsProcessPublicSpreadTicker(respRaw []byte) error { +func (e *Exchange) wsProcessPublicSpreadTicker(ctx context.Context, respRaw []byte) error { var resp WsSpreadPushData data := []WsSpreadPublicTicker{} resp.Data = &data @@ -794,8 +796,7 @@ func (e *Exchange) wsProcessPublicSpreadTicker(respRaw []byte) error { LastUpdated: data[x].Timestamp.Time(), } } - e.Websocket.DataHandler <- tickers - return nil + return e.Websocket.DataHandler.Send(ctx, tickers) } // wsProcessPublicSpreadTrades retrieve the recent trades data from sprd-public-trades. @@ -991,7 +992,7 @@ func (e *Exchange) wsProcessOrderBooks(data []byte) error { }, }) if err != nil { - e.Websocket.DataHandler <- err + return err } } else { return err @@ -1132,7 +1133,7 @@ func (e *Exchange) CalculateOrderbookChecksum(orderbookData *WsOrderBookData) (u } // wsHandleMarkPriceCandles processes candlestick mark price push data as a result of subscription to "mark-price-candle*" channel. -func (e *Exchange) wsHandleMarkPriceCandles(data []byte) error { +func (e *Exchange) wsHandleMarkPriceCandles(ctx context.Context, data []byte) error { m := &struct { Argument SubscriptionInfo `json:"arg"` Data [][5]types.Number `json:"data"` @@ -1152,12 +1153,11 @@ func (e *Exchange) wsHandleMarkPriceCandles(data []byte) error { ClosePrice: m.Data[x][4].Float64(), } } - e.Websocket.DataHandler <- candles - return nil + return e.Websocket.DataHandler.Send(ctx, candles) } // wsProcessTrades handles a list of trade information. -func (e *Exchange) wsProcessTrades(data []byte) error { +func (e *Exchange) wsProcessTrades(ctx context.Context, data []byte) error { var response WsTradeOrder err := json.Unmarshal(data, &response) if err != nil { @@ -1204,7 +1204,9 @@ func (e *Exchange) wsProcessTrades(data []byte) error { } if tradeFeed { for i := range trades { - e.Websocket.DataHandler <- trades[i] + if err := e.Websocket.DataHandler.Send(ctx, trades[i]); err != nil { + return err + } } } if saveTradeData { @@ -1214,7 +1216,7 @@ func (e *Exchange) wsProcessTrades(data []byte) error { } // wsProcessOrders handles websocket order push data responses. -func (e *Exchange) wsProcessOrders(respRaw []byte) error { +func (e *Exchange) wsProcessOrders(ctx context.Context, respRaw []byte) error { var response WsOrderResponse err := json.Unmarshal(respRaw, &response) if err != nil { @@ -1227,19 +1229,11 @@ func (e *Exchange) wsProcessOrders(respRaw []byte) error { for x := range response.Data { orderType, err := order.StringToOrderType(response.Data[x].OrderType) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[x].OrderID, - Err: err, - } + return err } orderStatus, err := order.StringToOrderStatus(response.Data[x].State) if err != nil { - e.Websocket.DataHandler <- order.ClassificationError{ - Exchange: e.Name, - OrderID: response.Data[x].OrderID, - Err: err, - } + return err } pair, err := currency.NewPairFromString(response.Data[x].InstrumentID) if err != nil { @@ -1300,13 +1294,15 @@ func (e *Exchange) wsProcessOrders(respRaw []byte) error { d.Amount = d.ExecutedAmount } } - e.Websocket.DataHandler <- d + if err := e.Websocket.DataHandler.Send(ctx, d); err != nil { + return err + } } return nil } // wsProcessCandles handler to get a list of candlestick messages. -func (e *Exchange) wsProcessCandles(respRaw []byte) error { +func (e *Exchange) wsProcessCandles(ctx context.Context, respRaw []byte) error { if respRaw == nil { return common.ErrNilPointer } @@ -1337,7 +1333,7 @@ func (e *Exchange) wsProcessCandles(respRaw []byte) error { candleInterval := strings.TrimPrefix(response.Argument.Channel, candle) for i := range response.Data { for j := range assets { - e.Websocket.DataHandler <- websocket.KlineData{ + if err := e.Websocket.DataHandler.Send(ctx, websocket.KlineData{ Timestamp: time.UnixMilli(response.Data[i][0].Int64()), Pair: response.Argument.InstrumentID, AssetType: assets[j], @@ -1348,6 +1344,8 @@ func (e *Exchange) wsProcessCandles(respRaw []byte) error { HighPrice: response.Data[i][2].Float64(), LowPrice: response.Data[i][3].Float64(), Volume: response.Data[i][5].Float64(), + }); err != nil { + return err } } } @@ -1355,7 +1353,7 @@ func (e *Exchange) wsProcessCandles(respRaw []byte) error { } // wsProcessTickers handles the trade ticker information. -func (e *Exchange) wsProcessTickers(data []byte) error { +func (e *Exchange) wsProcessTickers(ctx context.Context, data []byte) error { var response WSTickerResponse err := json.Unmarshal(data, &response) if err != nil { @@ -1401,7 +1399,9 @@ func (e *Exchange) wsProcessTickers(data []byte) error { Pair: response.Data[i].InstrumentID, LastUpdated: response.Data[i].TickerDataGenerationTime.Time(), } - e.Websocket.DataHandler <- tickData + if err := e.Websocket.DataHandler.Send(ctx, tickData); err != nil { + return err + } } } return nil @@ -1473,16 +1473,17 @@ func (e *Exchange) wsProcessBalanceAndPosition(ctx context.Context, data []byte) if err := e.Accounts.Save(ctx, subAccts, false); err != nil { return err } - e.Websocket.DataHandler <- subAccts - return nil + return e.Websocket.DataHandler.Send(ctx, subAccts) } // wsProcessPushData processes push data coming through the websocket channel -func (e *Exchange) wsProcessPushData(data []byte, resp any) error { +func (e *Exchange) wsProcessPushData(ctx context.Context, data []byte, resp any) error { if err := json.Unmarshal(data, resp); err != nil { return err } - e.Websocket.DataHandler <- resp + if err := e.Websocket.DataHandler.Send(ctx, resp); err != nil { + return err + } return nil } diff --git a/exchanges/poloniex/poloniex_test.go b/exchanges/poloniex/poloniex_test.go index eea1b856eb8..dc83c369e7e 100644 --- a/exchanges/poloniex/poloniex_test.go +++ b/exchanges/poloniex/poloniex_test.go @@ -562,7 +562,7 @@ func TestWsAuth(t *testing.T) { } timer := time.NewTimer(sharedtestvalues.WebsocketResponseDefaultTimeout) select { - case response := <-e.Websocket.DataHandler: + case response := <-e.Websocket.DataHandler.C: t.Error(response) case <-timer.C: } diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 688d222cdfd..ab22192e1d4 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -110,9 +110,10 @@ func (e *Exchange) wsReadData(ctx context.Context) { if resp.Raw == nil { return } - err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - e.Websocket.DataHandler <- fmt.Errorf("%s: %w", e.Name, err) + if err := e.wsHandleData(ctx, resp.Raw); err != nil { + if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { + log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) + } } } } @@ -170,17 +171,17 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { switch updateType { case accountNotificationPendingOrder: - err = e.processAccountPendingOrder(notification) + err = e.processAccountPendingOrder(ctx, notification) if err != nil { return fmt.Errorf("account notification pending order: %w", err) } case accountNotificationOrderUpdate: - err = e.processAccountOrderUpdate(notification) + err = e.processAccountOrderUpdate(ctx, notification) if err != nil { return fmt.Errorf("account notification order update: %w", err) } case accountNotificationOrderLimitCreated: - err = e.processAccountOrderLimit(notification) + err = e.processAccountOrderLimit(ctx, notification) if err != nil { return fmt.Errorf("account notification limit order creation: %w", err) } @@ -190,17 +191,17 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { return fmt.Errorf("account notification balance update: %w", err) } case accountNotificationTrades: - err = e.processAccountTrades(notification) + err = e.processAccountTrades(ctx, notification) if err != nil { return fmt.Errorf("account notification trades: %w", err) } case accountNotificationKilledOrder: - err = e.processAccountKilledOrder(notification) + err = e.processAccountKilledOrder(ctx, notification) if err != nil { return fmt.Errorf("account notification killed order: %w", err) } case accountNotificationMarginPosition: - err = e.processAccountMarginPosition(notification) + err = e.processAccountMarginPosition(ctx, notification) if err != nil { return fmt.Errorf("account notification margin position: %w", err) } @@ -210,7 +211,7 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { } return nil case wsTickerDataID: - err = e.wsHandleTickerData(data) + err = e.wsHandleTickerData(ctx, data) if err != nil { return fmt.Errorf("websocket ticker process: %w", err) } @@ -264,15 +265,17 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { return fmt.Errorf("websocket process trades update: %w", err) } default: - e.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + if err := e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ Message: e.Name + websocket.UnhandledMessage + string(respRaw), + }); err != nil { + return err } } } return nil } -func (e *Exchange) wsHandleTickerData(data []any) error { +func (e *Exchange) wsHandleTickerData(ctx context.Context, data []any) error { tickerData, ok := data[2].([]any) if !ok { return fmt.Errorf("%w ticker data is not []any", @@ -360,7 +363,7 @@ func (e *Exchange) wsHandleTickerData(data []any) error { // highestTradeIn24Hm, ok := tickerData[8].(string) // lowestTradePrice24H, ok := tickerData[9].(string) - e.Websocket.DataHandler <- &ticker.Price{ + return e.Websocket.DataHandler.Send(ctx, &ticker.Price{ ExchangeName: e.Name, Volume: baseCurrencyVolume24H, QuoteVolume: quoteCurrencyVolume24H, @@ -371,8 +374,7 @@ func (e *Exchange) wsHandleTickerData(data []any) error { Last: lastPrice, AssetType: asset.Spot, Pair: pair, - } - return nil + }) } // WsProcessOrderbookSnapshot processes a new orderbook snapshot into a local @@ -641,7 +643,7 @@ func (e *Exchange) wsSendAuthorisedCommand(ctx context.Context, secret, key stri return e.Websocket.Conn.SendJSONMessage(ctx, request.Unset, req) } -func (e *Exchange) processAccountMarginPosition(notification []any) error { +func (e *Exchange) processAccountMarginPosition(ctx context.Context, notification []any) error { if len(notification) < 5 { return errNotEnoughData } @@ -674,7 +676,7 @@ func (e *Exchange) processAccountMarginPosition(notification []any) error { clientOrderID, _ := notification[4].(string) // Temp struct for margin position changes - e.Websocket.DataHandler <- struct { + return e.Websocket.DataHandler.Send(ctx, struct { OrderID string Code currency.Code Amount float64 @@ -684,12 +686,10 @@ func (e *Exchange) processAccountMarginPosition(notification []any) error { Code: code, Amount: amount, ClientOrderID: clientOrderID, - } - - return nil + }) } -func (e *Exchange) processAccountPendingOrder(notification []any) error { +func (e *Exchange) processAccountPendingOrder(ctx context.Context, notification []any) error { if len(notification) < 7 { return errNotEnoughData } @@ -742,7 +742,7 @@ func (e *Exchange) processAccountPendingOrder(notification []any) error { // null returned so ok check is not needed clientOrderID, _ := notification[6].(string) - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, OrderID: strconv.FormatFloat(orderID, 'f', -1, 64), Pair: pair, @@ -753,11 +753,10 @@ func (e *Exchange) processAccountPendingOrder(notification []any) error { RemainingAmount: orderAmount, ClientOrderID: clientOrderID, Status: order.Pending, - } - return nil + }) } -func (e *Exchange) processAccountOrderUpdate(notification []any) error { +func (e *Exchange) processAccountOrderUpdate(ctx context.Context, notification []any) error { if len(notification) < 5 { return errNotEnoughData } @@ -813,7 +812,7 @@ func (e *Exchange) processAccountOrderUpdate(notification []any) error { // null returned so ok check is not needed clientOrderID, _ := notification[4].(string) - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, RemainingAmount: cancelledAmount, Amount: amount + cancelledAmount, @@ -823,11 +822,10 @@ func (e *Exchange) processAccountOrderUpdate(notification []any) error { Status: oStatus, AssetType: asset.Spot, ClientOrderID: clientOrderID, - } - return nil + }) } -func (e *Exchange) processAccountOrderLimit(notification []any) error { +func (e *Exchange) processAccountOrderLimit(ctx context.Context, notification []any) error { if len(notification) != 9 { return errNotEnoughData } @@ -900,7 +898,7 @@ func (e *Exchange) processAccountOrderLimit(notification []any) error { // null returned so ok check is not needed clientOrderID, _ := notification[8].(string) - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, Price: orderPrice, RemainingAmount: orderAmount, @@ -914,8 +912,7 @@ func (e *Exchange) processAccountOrderLimit(notification []any) error { Date: timeParse, Pair: pair, ClientOrderID: clientOrderID, - } - return nil + }) } func (e *Exchange) processAccountBalanceUpdate(ctx context.Context, notification []any) error { @@ -959,8 +956,7 @@ func (e *Exchange) processAccountBalanceUpdate(ctx context.Context, notification if err := e.Accounts.Save(ctx, subAccts, true); err != nil { return err } - e.Websocket.DataHandler <- subAccts - return nil + return e.Websocket.DataHandler.Send(ctx, subAccts) } func deriveWalletType(s string) string { @@ -976,7 +972,7 @@ func deriveWalletType(s string) string { } } -func (e *Exchange) processAccountTrades(notification []any) error { +func (e *Exchange) processAccountTrades(ctx context.Context, notification []any) error { if len(notification) < 11 { return errNotEnoughData } @@ -1043,7 +1039,7 @@ func (e *Exchange) processAccountTrades(notification []any) error { return err } - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, OrderID: strconv.FormatFloat(orderID, 'f', -1, 64), Fee: totalFee, @@ -1058,11 +1054,10 @@ func (e *Exchange) processAccountTrades(notification []any) error { }}, AssetType: asset.Spot, ClientOrderID: clientOrderID, - } - return nil + }) } -func (e *Exchange) processAccountKilledOrder(notification []any) error { +func (e *Exchange) processAccountKilledOrder(ctx context.Context, notification []any) error { if len(notification) < 3 { return errNotEnoughData } @@ -1075,14 +1070,13 @@ func (e *Exchange) processAccountKilledOrder(notification []any) error { // null returned so ok check is not needed clientOrderID, _ := notification[2].(string) - e.Websocket.DataHandler <- &order.Detail{ + return e.Websocket.DataHandler.Send(ctx, &order.Detail{ Exchange: e.Name, OrderID: strconv.FormatFloat(orderID, 'f', -1, 64), Status: order.Cancelled, AssetType: asset.Spot, ClientOrderID: clientOrderID, - } - return nil + }) } func (e *Exchange) processTrades(currencyID float64, subData []any) error { diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index 7bb13ce37cd..26466b9b573 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchange/websocket" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -53,8 +54,7 @@ func GetWebsocketStructChannelOverride() chan struct{} { // NewTestWebsocket returns a test websocket object func NewTestWebsocket() *websocket.Manager { w := websocket.NewManager() - w.DataHandler = make(chan any, WebsocketChannelOverrideCapacity) - w.ToRoutine = make(chan any, 1000) + w.DataHandler = stream.NewRelay(WebsocketChannelOverrideCapacity) return w } @@ -76,16 +76,16 @@ func SkipTestIfCredentialsUnset(t *testing.T, exch exchange.IBotExchange, canMan return } - message := []string{warningSkip} + out := []string{warningSkip} if !areTestAPICredentialsSet { - message = append(message, warningKeys) + out = append(out, warningKeys) } if supportsManipulatingOrders && !allowedToManipulateOrders { - message = append(message, warningManipulateOrders) + out = append(out, warningManipulateOrders) } - message = append(message, warningHowTo) - t.Skip(strings.Join(message, ", ")) + out = append(out, warningHowTo) + t.Skip(strings.Join(out, ", ")) } // SkipTestIfCannotManipulateOrders will only skip if the credentials are set diff --git a/exchanges/trade/trade.go b/exchanges/trade/trade.go index 6b3682d58c8..12a7fa261e7 100644 --- a/exchanges/trade/trade.go +++ b/exchanges/trade/trade.go @@ -1,6 +1,7 @@ package trade import ( + "context" "errors" "fmt" "slices" @@ -14,6 +15,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/database" tradesql "github.com/thrasher-corp/gocryptotrader/database/repository/trade" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" @@ -30,7 +32,7 @@ func (p *Processor) setup(wg *sync.WaitGroup) { // Setup configures necessary fields to the `Trade` structure that govern trade data // processing. -func (t *Trade) Setup(tradeFeedEnabled bool, c chan any) { +func (t *Trade) Setup(tradeFeedEnabled bool, c *stream.Relay) { t.dataHandler = c t.tradeFeedEnabled = tradeFeedEnabled } @@ -38,13 +40,16 @@ func (t *Trade) Setup(tradeFeedEnabled bool, c chan any) { // Update processes trade data, either by saving it or routing it through // the data channel. func (t *Trade) Update(save bool, data ...Data) error { + ctx := context.TODO() if len(data) == 0 { // nothing to do return nil } if t.tradeFeedEnabled { - t.dataHandler <- data + if err := t.dataHandler.Send(ctx, data); err != nil { + return err + } } if save { diff --git a/exchanges/trade/trade_types.go b/exchanges/trade/trade_types.go index a2d0a18460f..3a6aa344082 100644 --- a/exchanges/trade/trade_types.go +++ b/exchanges/trade/trade_types.go @@ -7,6 +7,7 @@ import ( "github.com/gofrs/uuid" "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchange/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" ) @@ -27,7 +28,7 @@ var ( // Trade used to hold data and methods related to trade dissemination and // storage type Trade struct { - dataHandler chan any + dataHandler *stream.Relay tradeFeedEnabled bool } diff --git a/internal/testing/exchange/exchange.go b/internal/testing/exchange/exchange.go index da2932538fc..ced9fdc2332 100644 --- a/internal/testing/exchange/exchange.go +++ b/internal/testing/exchange/exchange.go @@ -124,7 +124,7 @@ func MockWsInstance[T any, PT interface { // Exchanges which don't support subscription conf; Can be removed when all exchanges support sub conf b.Websocket.GenerateSubs = func() (subscription.List, error) { return subscription.List{}, nil } - err = b.Websocket.Connect() + err = b.Websocket.Connect(context.TODO()) require.NoError(tb, err, "Connect must not error") return e @@ -203,7 +203,7 @@ func SetupWs(tb testing.TB, e exchange.IBotExchange) { // Exchanges which don't support subscription conf; Can be removed when all exchanges support sub conf w.GenerateSubs = func() (subscription.List, error) { return subscription.List{}, nil } - err = w.Connect() + err = w.Connect(context.TODO()) require.NoError(tb, err, "Connect must not error") setupWsOnce[e] = true