diff --git a/contrib/spellcheck/ignore_words.txt b/contrib/spellcheck/ignore_words.txt index 8d9dba32c4a..5e2437afa5e 100644 --- a/contrib/spellcheck/ignore_words.txt +++ b/contrib/spellcheck/ignore_words.txt @@ -4,4 +4,5 @@ prevend flate zar insid -totalin \ No newline at end of file +totalin +bu \ No newline at end of file diff --git a/exchange/websocket/manager.go b/exchange/websocket/manager.go index c8b01f54387..990fc7c8d0d 100644 --- a/exchange/websocket/manager.go +++ b/exchange/websocket/manager.go @@ -495,7 +495,6 @@ func (m *Manager) connect(ctx context.Context) error { } if len(subs) == 0 { - // If no subscriptions are generated, we skip the connection if m.verbose { log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", m.exchangeName) } @@ -614,7 +613,6 @@ func (m *Manager) createConnectAndSubscribe(ctx context.Context, ws *websocket, } return nil } - if err := ws.setup.Subscriber(ctx, conn, subs); err != nil { return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err) } @@ -795,7 +793,6 @@ func (m *Manager) SetWebsocketURL(u string, auth, reconnect bool) error { if defaultVals { u = m.defaultURLAuth } - err := checkWebsocketURL(u) if err != nil { return err diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index 37aa1880adc..384c111a715 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -2,6 +2,7 @@ package bitfinex import ( "bufio" + "context" "log" "os" "strconv" @@ -508,7 +509,7 @@ func TestUpdateTicker(t *testing.T) { t.Parallel() _, err := e.UpdateTicker(t.Context(), btcusdPair, asset.Spot) - assert.NoError(t, common.ExcludeError(err, ticker.ErrBidEqualsAsk), "UpdateTicker may only error about locked markets") + assert.NoError(t, common.ExcludeError(err, ticker.ErrBidEqualsAsk), "UpdateTicker should only error about locked markets") } func TestUpdateTickers(t *testing.T) { @@ -541,7 +542,7 @@ func TestUpdateTickers(t *testing.T) { } } if !assert.Greaterf(t, okay/float64(len(avail))*100.0, acceptableThreshold, "At least %.f%% of %s tickers should not error", acceptableThreshold, a) { - assert.NoError(t, errs, "Collection of all the ticker errors") + assert.NoError(t, errs, "ticker error collection should be empty") } } } @@ -1260,7 +1261,7 @@ func TestWSSubscribe(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "TestInstance must not error") testexch.SetupWs(t, e) - err := e.Subscribe(subscription.List{{Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewBTCUSD()}, Asset: asset.Spot}}) + err := subscribeForTest(t.Context(), e, subscription.List{{Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewBTCUSD()}, Asset: asset.Spot}}) require.NoError(t, err, "Subscribe must not error") catcher := func() (ok bool) { i := <-e.Websocket.DataHandler.C @@ -1271,26 +1272,31 @@ func TestWSSubscribe(t *testing.T) { subs, err := e.GetSubscriptions() require.NoError(t, err, "GetSubscriptions must not error") - require.Len(t, subs, 1, "We must only have 1 subscription; subID subscription must have been Removed by subscribeToChan") + tickerSubs := make(subscription.List, 0, len(subs)) + for i := range subs { + if subs[i].Channel == subscription.TickerChannel && + subs[i].Asset == asset.Spot && + len(subs[i].Pairs) == 1 && + subs[i].Pairs[0].Equal(currency.NewBTCUSD()) { + tickerSubs = append(tickerSubs, subs[i]) + } + } + require.NotEmpty(t, tickerSubs, "there must be at least one BTC/USD ticker subscription") - err = e.Subscribe(subscription.List{{Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewBTCUSD()}, Asset: asset.Spot}}) + err = subscribeForTest(t.Context(), e, subscription.List{{Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewBTCUSD()}, Asset: asset.Spot}}) require.ErrorContains(t, err, "subscribe: dup (code: 10301)", "Duplicate subscription must error correctly") - subs, err = e.GetSubscriptions() - require.NoError(t, err, "GetSubscriptions must not error") - require.Len(t, subs, 1, "We must only have one subscription after an error attempt") - - err = e.Unsubscribe(subs) + err = unsubscribeForTest(t.Context(), e, subscription.List{tickerSubs[0]}) assert.NoError(t, err, "Unsubscribing should not error") - chanID, ok := subs[0].Key.(int) + chanID, ok := tickerSubs[0].Key.(int) assert.True(t, ok, "sub.Key should be an int") - err = e.Unsubscribe(subs) + err = unsubscribeForTest(t.Context(), e, subscription.List{tickerSubs[0]}) assert.ErrorContains(t, err, strconv.Itoa(chanID), "Unsubscribe should contain correct chanId") assert.ErrorContains(t, err, "unsubscribe: invalid (code: 10400)", "Unsubscribe should contain correct upstream error") - err = e.Subscribe(subscription.List{{ + err = subscribeForTest(t.Context(), e, subscription.List{{ Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewBTCUSD()}, Asset: asset.Spot, @@ -1299,6 +1305,51 @@ func TestWSSubscribe(t *testing.T) { assert.ErrorIs(t, err, errParamNotAllowed, "Trying to use a 'key' param should error errParamNotAllowed") } +func subscribeForTest(ctx context.Context, e *Exchange, subs subscription.List) error { + var err error + subs, err = subs.ExpandTemplates(e) + if err != nil { + return err + } + if !e.Websocket.IsConnected() { + if err := e.Websocket.Connect(ctx); err != nil { + return err + } + } + conn, err := e.Websocket.GetConnection(publicBitfinexWebsocketEndpoint) + if err != nil { + wsRunningURL, urlErr := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + if urlErr != nil { + return err + } + conn, err = e.Websocket.GetConnection(wsRunningURL) + if err != nil { + return err + } + } + return e.subscribeForConnection(ctx, conn, subs) +} + +func unsubscribeForTest(ctx context.Context, e *Exchange, subs subscription.List) error { + if !e.Websocket.IsConnected() { + if err := e.Websocket.Connect(ctx); err != nil { + return err + } + } + conn, err := e.Websocket.GetConnection(publicBitfinexWebsocketEndpoint) + if err != nil { + wsRunningURL, urlErr := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + if urlErr != nil { + return err + } + conn, err = e.Websocket.GetConnection(wsRunningURL) + if err != nil { + return err + } + } + return e.unsubscribeForConnection(ctx, conn, subs) +} + // TestSubToMap tests the channel to request map marshalling func TestSubToMap(t *testing.T) { s := &subscription.Subscription{ @@ -1399,18 +1450,22 @@ func TestWSCancelOffer(t *testing.T) { } func TestWSSubscribedResponse(t *testing.T) { + t.Parallel() + e := new(Exchange) + require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") + conn := testexch.GetMockConn(t, e, publicBitfinexWebsocketEndpoint) ch, err := e.Websocket.Match.Set("subscribe:waiter1", 1) assert.NoError(t, err, "Setting a matcher should not error") - err = e.wsHandleData(t.Context(), []byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) + err = e.wsHandleData(t.Context(), conn, []byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) if assert.Error(t, err, "Should error if sub is not registered yet") { assert.ErrorIs(t, err, websocket.ErrSubscriptionFailure, "Should error SubFailure if sub isn't registered yet") assert.ErrorIs(t, err, subscription.ErrNotFound, "Should error SubNotFound if sub isn't registered yet") assert.ErrorContains(t, err, "waiter1", "Should error containing subID if") } - err = e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "waiter1"}) + err = e.Websocket.AddSubscriptions(conn, &subscription.Subscription{Key: "waiter1"}) require.NoError(t, err, "AddSubscriptions must not error") - err = e.wsHandleData(t.Context(), []byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) + err = e.wsHandleData(t.Context(), conn, []byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) assert.NoError(t, err, "wsHandleData should not error") if assert.NotEmpty(t, ch, "Matcher should have received a sub notification") { msg := <-ch @@ -1421,20 +1476,21 @@ func TestWSSubscribedResponse(t *testing.T) { } func TestWSOrderBook(t *testing.T) { - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: 23405, Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.OrderbookChannel}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Key: 23405, Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.OrderbookChannel}) require.NoError(t, err, "AddSubscriptions must not error") + conn := testexch.GetMockConn(t, e, "") pressXToJSON := `[23405,[[38334303613,9348.8,0.53],[38334308111,9348.8,5.98979404],[38331335157,9344.1,1.28965787],[38334302803,9343.8,0.08230094],[38334279092,9343,0.8],[38334307036,9342.938663676,0.8],[38332749107,9342.9,0.2],[38332277330,9342.8,0.85],[38329406786,9342,0.1432012],[38332841570,9341.947288638,0.3],[38332163238,9341.7,0.3],[38334303384,9341.6,0.324],[38332464840,9341.4,0.5],[38331935870,9341.2,0.5],[38334312082,9340.9,0.02126899],[38334261292,9340.8,0.26763],[38334138680,9340.625455254,0.12],[38333896802,9339.8,0.85],[38331627527,9338.9,1.57863959],[38334186713,9338.9,0.26769],[38334305819,9338.8,2.999],[38334211180,9338.75285796,3.999],[38334310699,9337.8,0.10679883],[38334307414,9337.5,1],[38334179822,9337.1,0.26773],[38334306600,9336.659955102,1.79],[38334299667,9336.6,1.1],[38334306452,9336.6,0.13979771],[38325672859,9336.3,1.25],[38334311646,9336.2,1],[38334258509,9336.1,0.37],[38334310592,9336,1.79],[38334310378,9335.6,1.43],[38334132444,9335.2,0.26777],[38331367325,9335,0.07],[38334310703,9335,0.10680562],[38334298209,9334.7,0.08757301],[38334304857,9334.456899462,0.291],[38334309940,9334.088390727,0.0725],[38334310377,9333.7,1.2868],[38334297615,9333.607784,0.1108],[38334095188,9333.3,0.26785],[38334228913,9332.7,0.40861186],[38334300526,9332.363996604,0.3884],[38334310701,9332.2,0.10680562],[38334303548,9332.005382871,0.07],[38334311798,9331.8,0.41285228],[38334301012,9331.7,1.7952],[38334089877,9331.4,0.2679],[38321942150,9331.2,0.2],[38334310670,9330,1.069],[38334063096,9329.6,0.26796],[38334310700,9329.4,0.10680562],[38334310404,9329.3,1],[38334281630,9329.1,6.57150597],[38334036864,9327.7,0.26801],[38334310702,9326.6,0.10680562],[38334311799,9326.1,0.50220625],[38334164163,9326,0.219638],[38334309722,9326,1.5],[38333051682,9325.8,0.26807],[38334302027,9325.7,0.75],[38334203435,9325.366592,0.32397696],[38321967613,9325,0.05],[38334298787,9324.9,0.3],[38334301719,9324.8,3.6227592],[38331316716,9324.763454646,0.71442],[38334310698,9323.8,0.10680562],[38334035499,9323.7,0.23431017],[38334223472,9322.670551788,0.42150603],[38334163459,9322.560399006,0.143967],[38321825171,9320.8,2],[38334075805,9320.467496148,0.30772633],[38334075800,9319.916732238,0.61457592],[38333682302,9319.7,0.0011],[38331323088,9319.116771762,0.12913],[38333677480,9319,0.0199],[38334277797,9318.6,0.89],[38325235155,9318.041088,1.20249],[38334310910,9317.82382938,1.79],[38334311811,9317.2,0.61079138],[38334311812,9317.2,0.71937652],[38333298214,9317.1,50],[38334306359,9317,1.79],[38325531545,9316.382823951,0.21263],[38333727253,9316.3,0.02316372],[38333298213,9316.1,45],[38333836479,9316,2.135],[38324520465,9315.9,2.7681],[38334307411,9315.5,1],[38330313617,9315.3,0.84455],[38334077770,9315.294024,0.01248397],[38334286663,9315.294024,1],[38325533762,9315.290315394,2.40498],[38334310018,9315.2,3],[38333682617,9314.6,0.0011],[38334304794,9314.6,0.76364676],[38334304798,9314.3,0.69242113],[38332915733,9313.8,0.0199],[38334084411,9312.8,1],[38334311893,9350.1,-1.015],[38334302734,9350.3,-0.26737],[38334300732,9350.8,-5.2],[38333957619,9351,-0.90677089],[38334300521,9351,-1.6457],[38334301600,9351.012829557,-0.0523],[38334308878,9351.7,-2.5],[38334299570,9351.921544,-0.1015],[38334279367,9352.1,-0.26732],[38334299569,9352.411802928,-0.4036],[38334202773,9353.4,-0.02139404],[38333918472,9353.7,-1.96412776],[38334278782,9354,-0.26731],[38334278606,9355,-1.2785],[38334302105,9355.439221251,-0.79191542],[38313897370,9355.569409242,-0.43363],[38334292995,9355.584296,-0.0979],[38334216989,9355.8,-0.03686414],[38333894025,9355.9,-0.26721],[38334293798,9355.936691952,-0.4311],[38331159479,9356,-0.4204022],[38333918888,9356.1,-1.10885563],[38334298205,9356.4,-0.20124428],[38328427481,9356.5,-0.1],[38333343289,9356.6,-0.41034213],[38334297205,9356.6,-0.08835018],[38334277927,9356.741101161,-0.0737],[38334311645,9356.8,-0.5],[38334309002,9356.9,-5],[38334309736,9357,-0.10680107],[38334306448,9357.4,-0.18645275],[38333693302,9357.7,-0.2672],[38332815159,9357.8,-0.0011],[38331239824,9358.2,-0.02],[38334271608,9358.3,-2.999],[38334311971,9358.4,-0.55],[38333919260,9358.5,-1.9972841],[38334265365,9358.5,-1.7841],[38334277960,9359,-3],[38334274601,9359.020969848,-3],[38326848839,9359.1,-0.84],[38334291080,9359.247048,-0.16199869],[38326848844,9359.4,-1.84],[38333680200,9359.6,-0.26713],[38331326606,9359.8,-0.84454],[38334309738,9359.8,-0.10680107],[38331314707,9359.9,-0.2],[38333919803,9360.9,-1.41177599],[38323651149,9361.33417827,-0.71442],[38333656906,9361.5,-0.26705],[38334035500,9361.5,-0.40861586],[38334091886,9362.4,-6.85940815],[38334269617,9362.5,-4],[38323629409,9362.545858872,-2.40497],[38334309737,9362.7,-0.10680107],[38334312380,9362.7,-3],[38325280830,9362.8,-1.75123],[38326622800,9362.8,-1.05145],[38333175230,9363,-0.0011],[38326848745,9363.2,-0.79],[38334308960,9363.206775564,-0.12],[38333920234,9363.3,-1.25318113],[38326848843,9363.4,-1.29],[38331239823,9363.4,-0.02],[38333209613,9363.4,-0.26719],[38334299964,9364,-0.05583123],[38323470224,9364.161816648,-0.12912],[38334284711,9365,-0.21346019],[38334299594,9365,-2.6757062],[38323211816,9365.073132585,-0.21262],[38334312456,9365.1,-0.11167861],[38333209612,9365.2,-0.26719],[38327770474,9365.3,-0.0073],[38334298788,9365.3,-0.3],[38334075803,9365.409831204,-0.30772637],[38334309740,9365.5,-0.10680107],[38326608767,9365.7,-2.76809],[38333920657,9365.7,-1.25848083],[38329594226,9366.6,-0.02587],[38334311813,9366.7,-4.72290945],[38316386301,9367.39258128,-2.37581],[38334302026,9367.4,-4.5],[38334228915,9367.9,-0.81725458],[38333921381,9368.1,-1.72213641],[38333175678,9368.2,-0.0011],[38334301150,9368.2,-2.654604],[38334297208,9368.3,-0.78036466],[38334309739,9368.3,-0.10680107],[38331227515,9368.7,-0.02],[38331184470,9369,-0.003975],[38334203436,9369.319616,-0.32397695],[38334269964,9369.7,-0.5],[38328386732,9370,-4.11759935],[38332719555,9370,-0.025],[38333921935,9370.5,-1.2224398],[38334258511,9370.5,-0.35],[38326848842,9370.8,-0.34],[38333985038,9370.9,-0.8551502],[38334283018,9370.9,-1],[38326848744,9371,-1.34]],5]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), conn, []byte(pressXToJSON)) if err != nil { t.Error(err) } pressXToJSON = `[23405,[7617,52.98726298,7617.1,53.601795929999994,-550.9,-0.0674,7617,8318.92961981,8257.8,7500],6]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), conn, []byte(pressXToJSON)) if err != nil { t.Error(err) } pressXToJSON = `[23405,[7617,52.98726298,7617.1,53.601795929999994,-550.9,-0.0674,7617,8318.92961981,8257.8,7500]]` - assert.NotPanics(t, func() { err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) }, "handleWSBookUpdate should not panic when seqNo is not configured to be sent") + assert.NotPanics(t, func() { err = e.wsHandleData(t.Context(), conn, []byte(pressXToJSON)) }, "handleWSBookUpdate should not panic when seqNo is not configured to be sent") assert.ErrorIs(t, err, errNoSeqNo, "handleWSBookUpdate should send correct error") } @@ -1443,9 +1499,10 @@ func TestWSAllTrades(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.AllTradesChannel, Key: 18788}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.AllTradesChannel, Key: 18788}) require.NoError(t, err, "AddSubscriptions must not error") - testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() expJSON := []string{ `{"TID":"412685577","AssetType":"spot","Side":"BUY","Price":176.3,"Amount":11.1998,"Timestamp":"2020-01-29T03:27:24.802Z"}`, @@ -1476,10 +1533,10 @@ func TestWSAllTrades(t *testing.T) { } func TestWSTickerResponse(t *testing.T) { - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.TickerChannel, Key: 11534}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.TickerChannel, Key: 11534}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[11534,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)) if err != nil { t.Error(err) } @@ -1487,10 +1544,10 @@ func TestWSTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: subscription.TickerChannel, Key: 123412}) + err = e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: subscription.TickerChannel, Key: 123412}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123412,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)) if err != nil { t.Error(err) } @@ -1498,10 +1555,10 @@ func TestWSTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: subscription.TickerChannel, Key: 123413}) + err = e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: subscription.TickerChannel, Key: 123413}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123413,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)) if err != nil { t.Error(err) } @@ -1509,25 +1566,25 @@ func TestWSTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: subscription.TickerChannel, Key: 123414}) + err = e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: subscription.TickerChannel, Key: 123414}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123414,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)) if err != nil { t.Error(err) } } func TestWSCandleResponse(t *testing.T) { - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.CandlesChannel, Key: 343351}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: subscription.CandlesChannel, Key: 343351}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[343351,[[1574698260000,7379.785503,7383.8,7388.3,7379.785503,1.68829482]]]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)) if err != nil { t.Error(err) } pressXToJSON = `[343351,[1574698200000,7399.9,7379.7,7399.9,7371.8,41.63633658]]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)) if err != nil { t.Error(err) } @@ -1535,26 +1592,30 @@ func TestWSCandleResponse(t *testing.T) { func TestWSOrderSnapshot(t *testing.T) { pressXToJSON := `[0,"os",[[34930659963,null,1574955083558,"tETHUSD",1574955083558,1574955083573,0.201104,0.201104,"EXCHANGE LIMIT",null,null,null,0,"ACTIVE",null,null,120,0,0,0,null,null,null,0,0,null,null,null,"BFX",null,null,null]]]` - err := e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)) if err != nil { t.Error(err) } pressXToJSON = `[0,"oc",[34930659963,null,1574955083558,"tETHUSD",1574955083558,1574955354487,0.201104,0.201104,"EXCHANGE LIMIT",null,null,null,0,"CANCELED",null,null,120,0,0,0,null,null,null,0,0,null,null,null,"BFX",null,null,null]]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)) if err != nil { t.Error(err) } } func TestWSNotifications(t *testing.T) { + t.Parallel() + e := new(Exchange) + require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") + conn := testexch.GetMockConn(t, e, authenticatedBitfinexWebsocketEndpoint) pressXToJSON := `[0,"n",[1575282446099,"fon-req",null,null,[41238905,null,null,null,-1000,null,null,null,null,null,null,null,null,null,0.002,2,null,null,null,null,null],null,"SUCCESS","Submitting funding bid of 1000.0 USD at 0.2000 for 2 days."]]` - err := e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err := e.wsHandleData(t.Context(), conn, []byte(pressXToJSON)) if err != nil { t.Error(err) } pressXToJSON = `[0,"n",[1575287438.515,"on-req",null,null,[1185815098,null,1575287436979,"tETHUSD",1575287438515,1575287438515,-2.5,-2.5,"LIMIT",null,null,null,0,"ACTIVE",null,null,230,0,0,0,null,null,null,0,null,null,null,null,"API>BFX",null,null,null],null,"SUCCESS","Submitting limit sell order for -2.5 ETH."]]` - err = e.wsHandleData(t.Context(), []byte(pressXToJSON)) + err = e.wsHandleData(t.Context(), conn, []byte(pressXToJSON)) if err != nil { t.Error(err) } @@ -1562,76 +1623,65 @@ func TestWSNotifications(t *testing.T) { func TestWSFundingOfferSnapshotAndUpdate(t *testing.T) { pressXToJSON := `[0,"fos",[[41237920,"fETH",1573912039000,1573912039000,0.5,0.5,"LIMIT",null,null,0,"ACTIVE",null,null,null,0.0024,2,0,0,null,0,null]]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)); err != nil { t.Error(err) } pressXToJSON = `[0,"fon",[41238747,"fUST",1575026670000,1575026670000,5000,5000,"LIMIT",null,null,0,"ACTIVE",null,null,null,0.006000000000000001,30,0,0,null,0,null]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)); err != nil { t.Error(err) } } func TestWSFundingCreditSnapshotAndUpdate(t *testing.T) { pressXToJSON := `[0,"fcs",[[26223578,"fUST",1,1575052261000,1575296187000,350,0,"ACTIVE",null,null,null,0,30,1575052261000,1575293487000,0,0,null,0,null,0,"tBTCUST"],[26223711,"fUSD",-1,1575291961000,1575296187000,180,0,"ACTIVE",null,null,null,0.002,7,1575282446000,1575295587000,0,0,null,0,null,0,"tETHUSD"]]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)); err != nil { t.Error(err) } pressXToJSON = `[0,"fcu",[26223578,"fUST",1,1575052261000,1575296787000,350,0,"ACTIVE",null,null,null,0,30,1575052261000,1575293487000,0,0,null,0,null,0,"tBTCUST"]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)); err != nil { t.Error(err) } } func TestWSFundingLoanSnapshotAndUpdate(t *testing.T) { pressXToJSON := `[0,"fls",[[2995442,"fUSD",-1,1575291961000,1575295850000,820,0,"ACTIVE",null,null,null,0.002,7,1575282446000,1575295850000,0,0,null,0,null,0]]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)); err != nil { t.Error(err) } pressXToJSON = `[0,"fln",[2995444,"fUSD",-1,1575298742000,1575298742000,1000,0,"ACTIVE",null,null,null,0.002,7,1575298742000,1575298742000,0,0,null,0,null,0]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { - t.Error(err) - } -} - -func TestWSWalletSnapshot(t *testing.T) { - pressXToJSON := `[0,"ws",[["exchange","SAN",19.76,0,null,null,null]]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { - t.Error(err) - } -} - -func TestWSBalanceUpdate(t *testing.T) { - const pressXToJSON = `[0,"bu",[4131.85,4131.85]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)); err != nil { t.Error(err) } } -func TestWSMarginInfoUpdate(t *testing.T) { - const pressXToJSON = `[0,"miu",["base",[-13.014640000000007,0,49331.70267297,49318.68803297,27]]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { - t.Error(err) +func TestWSAccountAndFundingUpdates(t *testing.T) { + t.Parallel() + tests := map[string]string{ + "wallet snapshot": `[0,"ws",[["exchange","SAN",19.76,0,null,null,null]]]`, + "balance update": `[0,"bu",[4131.85,4131.85]]`, + "margin info update": `[0,"miu",["base",[-13.014640000000007,0,49331.70267297,49318.68803297,27]]]`, + "funding info": `[0,"fiu",["sym","tETHUSD",[149361.09689202666,149639.26293509,830.0182168075556,895.0658432466332]]]`, } -} -func TestWSFundingInfoUpdate(t *testing.T) { - const pressXToJSON = `[0,"fiu",["sym","tETHUSD",[149361.09689202666,149639.26293509,830.0182168075556,895.0658432466332]]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { - t.Error(err) + for name, payload := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + require.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(payload))) + }) } } func TestWSFundingTrade(t *testing.T) { pressXToJSON := `[0,"fte",[636854,"fUSD",1575282446000,41238905,-1000,0.002,7,null]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)); err != nil { t.Error(err) } pressXToJSON = `[0,"ftu",[636854,"fUSD",1575282446000,41238905,-1000,0.002,7,null]]` - if err := e.wsHandleData(t.Context(), []byte(pressXToJSON)); err != nil { + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(pressXToJSON)); err != nil { t.Error(err) } } @@ -2016,8 +2066,8 @@ func TestGetCurrencyTradeURL(t *testing.T) { testexch.UpdatePairsOnce(t, e) for _, a := range e.GetAssetTypes(false) { pairs, err := e.CurrencyPairs.GetPairs(a, false) - require.NoErrorf(t, err, "cannot get pairs for %s", a) - require.NotEmptyf(t, pairs, "no pairs for %s", a) + require.NoErrorf(t, err, "GetPairs must not error for asset %s", a) + require.NotEmptyf(t, pairs, "pairs must not be empty for asset %s", a) resp, err := e.GetCurrencyTradeURL(t.Context(), a, pairs[0]) require.NoError(t, err) assert.NotEmpty(t, resp) diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index c9e02546ffa..b479ade75a9 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -23,6 +23,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchange/websocket" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" @@ -119,71 +120,21 @@ var subscriptionNames = map[string]string{ subscription.AllTradesChannel: wsTradesChannel, } -// WsConnect starts a new websocket connection -func (e *Exchange) WsConnect() error { - ctx := context.TODO() - if !e.Websocket.IsEnabled() || !e.IsEnabled() { - return websocket.ErrWebsocketNotEnabled - } - var dialer gws.Dialer - err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}, nil) - if err != nil { - return fmt.Errorf("%v unable to connect to Websocket. Error: %s", - e.Name, - err) - } - - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.Conn) - if e.Websocket.CanUseAuthenticatedEndpoints() { - err = e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{}, nil) - if err != nil { - log.Errorf(log.ExchangeSys, - "%v unable to connect to authenticated Websocket. Error: %s", - e.Name, - err) - e.Websocket.SetCanUseAuthenticatedEndpoints(false) - } - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.AuthConn) - err = e.WsSendAuth(ctx) - if err != nil { - log.Errorf(log.ExchangeSys, - "%v - authentication failed: %v\n", - e.Name, - err) - e.Websocket.SetCanUseAuthenticatedEndpoints(false) - } - } - - e.Websocket.Wg.Add(1) - return e.ConfigureWS(ctx) -} - -// wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) { - defer e.Websocket.Wg.Done() - for { - resp := ws.ReadMessage() - if resp.Raw == nil { - return - } - if err := e.wsHandleData(ctx, resp.Raw); err != nil { - if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { - log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) - } - } +func (e *Exchange) wsConnect(ctx context.Context, conn websocket.Connection) error { + if err := conn.Dial(ctx, &gws.Dialer{}, http.Header{}, nil); err != nil { + return fmt.Errorf("%v unable to connect to Websocket. Error: %s", e.Name, err) } + return e.ConfigureWS(ctx, conn) } -func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, conn websocket.Connection, respRaw []byte) error { var result any if err := json.Unmarshal(respRaw, &result); err != nil { return err } switch d := result.(type) { case map[string]any: - return e.handleWSEvent(ctx, respRaw) + return e.handleWSEvent(ctx, conn, respRaw) case []any: chanIDFloat, ok := d[0].(float64) if !ok { @@ -195,7 +146,7 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { if chanID != 0 { if s := e.Websocket.GetSubscription(chanID); s != nil { - return e.handleWSChannelUpdate(ctx, s, respRaw, eventType, d) + return e.handleWSChannelUpdate(ctx, conn, s, respRaw, eventType, d) } if e.Verbose { log.Warnf(log.ExchangeSys, "%s %s; dropped WS message: %s", e.Name, subscription.ErrNotFound, respRaw) @@ -214,7 +165,7 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { case wsHeartbeat, pong: return nil case wsNotification: - return e.handleWSNotification(ctx, d, respRaw) + return e.handleWSNotification(ctx, conn, d, respRaw) case wsOrderSnapshot: if snapBundle, ok := d[2].([]any); ok && len(snapBundle) > 0 { if _, ok := snapBundle[0].([]any); ok { @@ -476,31 +427,31 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { return nil } -func (e *Exchange) handleWSEvent(ctx context.Context, respRaw []byte) error { +func (e *Exchange) handleWSEvent(ctx context.Context, conn websocket.Connection, respRaw []byte) error { event, err := jsonparser.GetUnsafeString(respRaw, "event") if err != nil { return fmt.Errorf("%w 'event': %w from message: %s", common.ErrParsingWSField, err, respRaw) } switch event { case wsEventSubscribed: - return e.handleWSSubscribed(respRaw) + return e.handleWSSubscribed(conn, respRaw) case wsEventUnsubscribed: chanID, err := jsonparser.GetUnsafeString(respRaw, "chanId") if err != nil { return fmt.Errorf("%w 'chanId': %w from message: %s", common.ErrParsingWSField, err, respRaw) } - err = e.Websocket.Match.RequireMatchWithData("unsubscribe:"+chanID, respRaw) + err = conn.RequireMatchWithData("unsubscribe:"+chanID, respRaw) if err != nil { return fmt.Errorf("%w: unsubscribe:%v", err, chanID) } case wsEventError: if subID, err := jsonparser.GetUnsafeString(respRaw, "subId"); err == nil { - err = e.Websocket.Match.RequireMatchWithData("subscribe:"+subID, respRaw) + err = conn.RequireMatchWithData("subscribe:"+subID, respRaw) if err != nil { return fmt.Errorf("%w: subscribe:%v", err, subID) } } else if chanID, err := jsonparser.GetUnsafeString(respRaw, "chanId"); err == nil { - err = e.Websocket.Match.RequireMatchWithData("unsubscribe:"+chanID, respRaw) + err = conn.RequireMatchWithData("unsubscribe:"+chanID, respRaw) if err != nil { return fmt.Errorf("%w: unsubscribe:%v", err, chanID) } @@ -543,9 +494,12 @@ func (e *Exchange) handleWSEvent(ctx context.Context, respRaw []byte) error { return nil } -// handleWSSubscribed parses a subscription response and registers the chanID key immediately, before updating subscribeToChan via IncomingWithData chan -// wsHandleData happens sequentially, so by rekeying on chanID immediately we ensure the first message is not dropped -func (e *Exchange) handleWSSubscribed(respRaw []byte) error { +// handleWSSubscribed parses a subscription response and transitions a temporary +// subID-keyed subscription to the final chanID key. +// wsHandleData happens sequentially, so by rekeying on chanID immediately we +// ensure the first message is not dropped and manager missing checks can +// validate against the final key. +func (e *Exchange) handleWSSubscribed(conn websocket.Connection, respRaw []byte) error { subID, err := jsonparser.GetUnsafeString(respRaw, "subId") if err != nil { return fmt.Errorf("%w 'subId': %w from message: %s", common.ErrParsingWSField, err, respRaw) @@ -561,24 +515,24 @@ func (e *Exchange) handleWSSubscribed(respRaw []byte) error { return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", websocket.ErrSubscriptionFailure, common.ErrParsingWSField, err, c.Channel, c.Pairs) } - // Note: chanID's int type avoids conflicts with the string type subID key because of the type difference - c = c.Clone() - c.Key = int(chanID) - - // subscribeToChan removes the old subID keyed Subscription - err = e.Websocket.AddSuccessfulSubscriptions(e.Websocket.Conn, c) + // Transition subID keyed subscription -> chanID keyed subscription. + // Remove sets unsubscribed state, AddSuccessful sets subscribed state. + err = e.Websocket.RemoveSubscriptions(conn, c) + if err != nil { + return fmt.Errorf("%w: %w subID: %s", websocket.ErrSubscriptionFailure, err, subID) + } + c.SetKey(int(chanID)) // chanID type avoids conflicts with string subID keys + err = e.Websocket.AddSuccessfulSubscriptions(conn, c) if err != nil { return fmt.Errorf("%w: %w subID: %s", websocket.ErrSubscriptionFailure, err, subID) } - if e.Verbose { log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", e.Name, c.Channel, c.Pairs, chanID) } - - return e.Websocket.Match.RequireMatchWithData("subscribe:"+subID, respRaw) + return conn.RequireMatchWithData("subscribe:"+subID, respRaw) } -func (e *Exchange) handleWSChannelUpdate(ctx context.Context, s *subscription.Subscription, respRaw []byte, eventType string, d []any) error { +func (e *Exchange) handleWSChannelUpdate(ctx context.Context, conn websocket.Connection, s *subscription.Subscription, respRaw []byte, eventType string, d []any) error { if s == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -596,7 +550,7 @@ func (e *Exchange) handleWSChannelUpdate(ctx context.Context, s *subscription.Su switch s.Channel { case subscription.OrderbookChannel: - return e.handleWSBookUpdate(ctx, s, d) + return e.handleWSBookUpdate(ctx, conn, s, d) case subscription.CandlesChannel: return e.handleWSAllCandleUpdates(ctx, s, respRaw) case subscription.TickerChannel: @@ -642,7 +596,7 @@ func (e *Exchange) handleWSChecksum(c *subscription.Subscription, d []any) error return nil } -func (e *Exchange) handleWSBookUpdate(ctx context.Context, c *subscription.Subscription, d []any) error { +func (e *Exchange) handleWSBookUpdate(ctx context.Context, conn websocket.Connection, c *subscription.Subscription, d []any) error { if c == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -738,7 +692,7 @@ func (e *Exchange) handleWSBookUpdate(ctx context.Context, c *subscription.Subsc }) } - if err := e.WsUpdateOrderbook(ctx, c, c.Pairs[0], c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil { + if err := e.WsUpdateOrderbook(ctx, conn, c, c.Pairs[0], c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil { return fmt.Errorf("updating orderbook error: %s", err) } @@ -944,7 +898,7 @@ func (e *Exchange) handleWSPublicTradeUpdate(respRaw []byte) (*Trade, error) { return t, json.Unmarshal(v, t) } -func (e *Exchange) handleWSNotification(ctx context.Context, d []any, respRaw []byte) error { +func (e *Exchange) handleWSNotification(ctx context.Context, conn websocket.Connection, d []any, respRaw []byte) error { notification, ok := d[2].([]any) if !ok { return errors.New("unable to type assert notification data") @@ -960,7 +914,7 @@ func (e *Exchange) handleWSNotification(ctx context.Context, d []any, respRaw [] strings.Contains(channelName, wsFundingOfferCancelRequest): if data[0] != nil { if id, ok := data[0].(float64); ok && id > 0 { - if e.Websocket.Match.IncomingWithData(int64(id), respRaw) { + if conn.IncomingWithData(int64(id), respRaw) { return nil } offer, err := wsHandleFundingOffer(data, true /* include rate real */) @@ -975,7 +929,7 @@ func (e *Exchange) handleWSNotification(ctx context.Context, d []any, respRaw [] if cid, ok := data[2].(float64); !ok { return common.GetTypeAssertError("float64", data[2], channelName+" cid") } else if cid > 0 { - if e.Websocket.Match.IncomingWithData(int64(cid), respRaw) { + if conn.IncomingWithData(int64(cid), respRaw) { return nil } return e.wsHandleOrder(ctx, data) @@ -987,7 +941,7 @@ func (e *Exchange) handleWSNotification(ctx context.Context, d []any, respRaw [] if id, ok := data[0].(float64); !ok { return common.GetTypeAssertError("float64", data[0], channelName+" id") } else if id > 0 { - if e.Websocket.Match.IncomingWithData(int64(id), respRaw) { + if conn.IncomingWithData(int64(id), respRaw) { return nil } return e.wsHandleOrder(ctx, data) @@ -1492,7 +1446,7 @@ func (e *Exchange) WsInsertSnapshot(p currency.Pair, assetType asset.Item, books // WsUpdateOrderbook updates the orderbook list, removing and adding to the // orderbook sides -func (e *Exchange) WsUpdateOrderbook(ctx context.Context, c *subscription.Subscription, p currency.Pair, assetType asset.Item, book []WebsocketBook, sequenceNo int64, fundingRate bool) error { +func (e *Exchange) WsUpdateOrderbook(ctx context.Context, conn websocket.Connection, c *subscription.Subscription, p currency.Pair, assetType asset.Item, book []WebsocketBook, sequenceNo int64, fundingRate bool) error { if c == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -1581,7 +1535,7 @@ func (e *Exchange) WsUpdateOrderbook(ctx context.Context, c *subscription.Subscr if err = validateCRC32(ob, checkme.Token); err != nil { log.Errorf(log.WebsocketMgr, "%s websocket orderbook update error, will resubscribe orderbook: %v", e.Name, err) - if e2 := e.resubOrderbook(ctx, c); e2 != nil { + if e2 := e.resubOrderbook(ctx, conn, c); e2 != nil { log.Errorf(log.WebsocketMgr, "%s error resubscribing orderbook: %v", e.Name, e2) } return err @@ -1594,7 +1548,7 @@ func (e *Exchange) WsUpdateOrderbook(ctx context.Context, c *subscription.Subscr // resubOrderbook resubscribes the orderbook after a consistency error, probably a failed checksum, // which forces a fresh snapshot. If we don't do this the orderbook will keep erroring and drifting. // Flushing the orderbook happens immediately, but the ReSub itself is a go routine to avoid blocking the WS data channel -func (e *Exchange) resubOrderbook(ctx context.Context, c *subscription.Subscription) error { +func (e *Exchange) resubOrderbook(ctx context.Context, conn websocket.Connection, c *subscription.Subscription) error { if c == nil { return fmt.Errorf("%w: Subscription param", common.ErrNilPointer) } @@ -1608,7 +1562,7 @@ func (e *Exchange) resubOrderbook(ctx context.Context, c *subscription.Subscript // Resub will block so we have to do this in a goro go func() { - if err := e.Websocket.ResubscribeToChannel(ctx, e.Websocket.Conn, c); err != nil { + if err := e.Websocket.ResubscribeToChannel(ctx, conn, c); err != nil { log.Errorf(log.ExchangeSys, "%s error resubscribing orderbook: %v", e.Name, err) } }() @@ -1633,36 +1587,32 @@ func (e *Exchange) GetSubscriptionTemplate(_ *subscription.Subscription) (*templ } // ConfigureWS to send checksums and sequence numbers -func (e *Exchange) ConfigureWS(ctx context.Context) error { - return e.Websocket.Conn.SendJSONMessage(ctx, request.Unset, map[string]any{ +func (e *Exchange) ConfigureWS(ctx context.Context, conn websocket.Connection) error { + return conn.SendJSONMessage(ctx, request.Unset, map[string]any{ "event": "conf", "flags": bitfinexChecksumFlag + bitfinexWsSequenceFlag, }) } -// Subscribe sends a websocket message to receive data from channels -func (e *Exchange) Subscribe(subs subscription.List) error { - ctx := context.TODO() - var err error - if subs, err = subs.ExpandTemplates(e); err != nil { - return err +func (e *Exchange) generatePublicSubscriptions() (subscription.List, error) { + subs, err := e.generateSubscriptions() + if err != nil { + return nil, err } - return e.ParallelChanOp(ctx, subs, e.subscribeToChan, 1) + return subs.Public(), nil } -// Unsubscribe sends a websocket message to stop receiving data from channels -func (e *Exchange) Unsubscribe(subs subscription.List) error { - ctx := context.TODO() - var err error - if subs, err = subs.ExpandTemplates(e); err != nil { - return err +func (e *Exchange) generatePrivateSubscriptions() (subscription.List, error) { + subs, err := e.generateSubscriptions() + if err != nil { + return nil, err } - return e.ParallelChanOp(ctx, subs, e.unsubscribeFromChan, 1) + return subs.Private(), nil } // subscribeToChan handles a single subscription and parses the result // on success it adds the subscription to the websocket -func (e *Exchange) subscribeToChan(ctx context.Context, subs subscription.List) error { +func (e *Exchange) subscribeToChan(ctx context.Context, conn websocket.Connection, subs subscription.List) error { if len(subs) != 1 { return subscription.ErrNotSinglePair } @@ -1683,21 +1633,18 @@ func (e *Exchange) subscribeToChan(ctx context.Context, subs subscription.List) // Add a temporary Key so we can find this Sub when we get the resp without delay or context switch // Otherwise we might drop the first messages after the subscribed resp s.Key = subID // Note subID string type avoids conflicts with later chanID key - if err := e.Websocket.AddSubscriptions(e.Websocket.Conn, s); err != nil { + if err := e.Websocket.AddSubscriptions(conn, s); err != nil { return fmt.Errorf("%w Channel: %s Pair: %s", err, s.Channel, s.Pairs) } - // Always remove the temporary subscription keyed by subID - defer func() { - _ = e.Websocket.RemoveSubscriptions(e.Websocket.Conn, s) - }() - - respRaw, err := e.Websocket.Conn.SendMessageReturnResponse(ctx, request.Unset, "subscribe:"+subID, req) + respRaw, err := conn.SendMessageReturnResponse(ctx, request.Unset, "subscribe:"+subID, req) if err != nil { + _ = e.Websocket.RemoveSubscriptions(conn, s) return fmt.Errorf("%w: Channel: %s Pair: %s", err, s.Channel, s.Pairs) } if err = e.getErrResp(respRaw); err != nil { + _ = e.Websocket.RemoveSubscriptions(conn, s) return fmt.Errorf("%w: Channel: %s Pair: %s", err, s.Channel, s.Pairs) } @@ -1705,7 +1652,7 @@ func (e *Exchange) subscribeToChan(ctx context.Context, subs subscription.List) } // unsubscribeFromChan sends a websocket message to stop receiving data from a channel -func (e *Exchange) unsubscribeFromChan(ctx context.Context, subs subscription.List) error { +func (e *Exchange) unsubscribeFromChan(ctx context.Context, conn websocket.Connection, subs subscription.List) error { if len(subs) != 1 { return errors.New("subscription batching limited to 1") } @@ -1720,7 +1667,7 @@ func (e *Exchange) unsubscribeFromChan(ctx context.Context, subs subscription.Li "chanId": chanID, } - respRaw, err := e.Websocket.Conn.SendMessageReturnResponse(ctx, request.Unset, "unsubscribe:"+strconv.Itoa(chanID), req) + respRaw, err := conn.SendMessageReturnResponse(ctx, request.Unset, "unsubscribe:"+strconv.Itoa(chanID), req) if err != nil { return err } @@ -1729,7 +1676,19 @@ func (e *Exchange) unsubscribeFromChan(ctx context.Context, subs subscription.Li return fmt.Errorf("%w: ChanId: %v", err, chanID) } - return e.Websocket.RemoveSubscriptions(e.Websocket.Conn, s) + return e.Websocket.RemoveSubscriptions(conn, s) +} + +func (e *Exchange) subscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return e.ParallelChanOp(ctx, subs, func(ctx context.Context, s subscription.List) error { + return e.subscribeToChan(ctx, conn, s) + }, 1) +} + +func (e *Exchange) unsubscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return e.ParallelChanOp(ctx, subs, func(ctx context.Context, s subscription.List) error { + return e.unsubscribeFromChan(ctx, conn, s) + }, 1) } // getErrResp takes a json response string and looks for an error event type @@ -1759,8 +1718,7 @@ func (e *Exchange) getErrResp(resp []byte) error { return fmt.Errorf("%w (code: %d)", apiErr, errCode) } -// WsSendAuth sends a authenticated event payload -func (e *Exchange) WsSendAuth(ctx context.Context) error { +func (e *Exchange) wsSendAuthConn(ctx context.Context, conn websocket.Connection) error { creds, err := e.GetCredentials(ctx) if err != nil { return err @@ -1774,7 +1732,7 @@ func (e *Exchange) WsSendAuth(ctx context.Context) error { return err } - return e.Websocket.AuthConn.SendJSONMessage(ctx, request.Unset, WsAuthRequest{ + return conn.SendJSONMessage(ctx, request.Unset, WsAuthRequest{ Event: "auth", APIKey: creds.Key, AuthPayload: payload, @@ -1784,11 +1742,23 @@ func (e *Exchange) WsSendAuth(ctx context.Context) error { }) } +func (e *Exchange) wsAuthConnection() (websocket.Connection, error) { + wsAuthURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + if err != nil { + return nil, err + } + return e.Websocket.GetConnection(wsAuthURL) +} + // WsNewOrder authenticated new order request func (e *Exchange) WsNewOrder(ctx context.Context, data *WsNewOrderRequest) (string, error) { data.CustomID = e.MessageSequence() req := makeRequestInterface(wsOrderNew, data) - resp, err := e.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, data.CustomID, req) + conn, err := e.wsAuthConnection() + if err != nil { + return "", err + } + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, data.CustomID, req) if err != nil { return "", err } @@ -1845,7 +1815,11 @@ func (e *Exchange) WsNewOrder(ctx context.Context, data *WsNewOrderRequest) (str // WsModifyOrder authenticated modify order request func (e *Exchange) WsModifyOrder(ctx context.Context, data *WsUpdateOrderRequest) error { req := makeRequestInterface(wsOrderUpdate, data) - resp, err := e.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, data.OrderID, req) + conn, err := e.wsAuthConnection() + if err != nil { + return err + } + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, data.OrderID, req) if err != nil { return err } @@ -1890,7 +1864,11 @@ func (e *Exchange) WsCancelMultiOrders(ctx context.Context, orderIDs []int64) er OrderID: orderIDs, } req := makeRequestInterface(wsCancelMultipleOrders, cancel) - return e.Websocket.AuthConn.SendJSONMessage(ctx, request.Unset, req) + conn, err := e.wsAuthConnection() + if err != nil { + return err + } + return conn.SendJSONMessage(ctx, request.Unset, req) } // WsCancelOrder authenticated cancel order request @@ -1899,7 +1877,11 @@ func (e *Exchange) WsCancelOrder(ctx context.Context, orderID int64) error { OrderID: orderID, } req := makeRequestInterface(wsOrderCancel, cancel) - resp, err := e.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, orderID, req) + conn, err := e.wsAuthConnection() + if err != nil { + return err + } + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, orderID, req) if err != nil { return err } @@ -1941,13 +1923,21 @@ func (e *Exchange) WsCancelOrder(ctx context.Context, orderID int64) error { func (e *Exchange) WsCancelAllOrders(ctx context.Context) error { cancelAll := WsCancelAllOrdersRequest{All: 1} req := makeRequestInterface(wsCancelMultipleOrders, cancelAll) - return e.Websocket.AuthConn.SendJSONMessage(ctx, request.Unset, req) + conn, err := e.wsAuthConnection() + if err != nil { + return err + } + return conn.SendJSONMessage(ctx, request.Unset, req) } // WsNewOffer authenticated new offer request func (e *Exchange) WsNewOffer(ctx context.Context, data *WsNewOfferRequest) error { req := makeRequestInterface(wsFundingOfferNew, data) - return e.Websocket.AuthConn.SendJSONMessage(ctx, request.Unset, req) + conn, err := e.wsAuthConnection() + if err != nil { + return err + } + return conn.SendJSONMessage(ctx, request.Unset, req) } // WsCancelOffer authenticated cancel offer request @@ -1956,7 +1946,11 @@ func (e *Exchange) WsCancelOffer(ctx context.Context, orderID int64) error { OrderID: orderID, } req := makeRequestInterface(wsFundingOfferCancel, cancel) - resp, err := e.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, orderID, req) + conn, err := e.wsAuthConnection() + if err != nil { + return err + } + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, orderID, req) if err != nil { return err } diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 548e999c0c0..abcb73949b1 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -161,8 +161,9 @@ func (e *Exchange) SetDefaults() { } e.API.Endpoints = e.NewEndpoints() err = e.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ - exchange.RestSpot: bitfinexAPIURLBase, - exchange.WebsocketSpot: publicBitfinexWebsocketEndpoint, + exchange.RestSpot: bitfinexAPIURLBase, + exchange.WebsocketSpot: publicBitfinexWebsocketEndpoint, + exchange.WebsocketSpotSupplementary: authenticatedBitfinexWebsocketEndpoint, }) if err != nil { log.Errorln(log.ExchangeSys, err) @@ -188,39 +189,51 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } - wsEndpoint, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + err = e.Websocket.Setup(&websocket.ManagerSetup{ + ExchangeConfig: exch, + Features: &e.Features.Supports.WebsocketCapabilities, + UseMultiConnectionManagement: true, + MaxWebsocketSubscriptionsPerConnection: 25, // https://docs.bitfinex.com/docs/requirements-and-limitations + }) if err != nil { return err } - - err = e.Websocket.Setup(&websocket.ManagerSetup{ - ExchangeConfig: exch, - DefaultURL: publicBitfinexWebsocketEndpoint, - RunningURL: wsEndpoint, - Connector: e.WsConnect, - Subscriber: e.Subscribe, - Unsubscriber: e.Unsubscribe, - GenerateSubscriptions: e.generateSubscriptions, - Features: &e.Features.Supports.WebsocketCapabilities, - }) + wsPublicURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + if err != nil { + return err + } + wsAuthURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) if err != nil { return err } err = e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - URL: publicBitfinexWebsocketEndpoint, + Connector: e.wsConnect, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generatePublicSubscriptions, + Handler: e.wsHandleData, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + URL: wsPublicURL, + MessageFilter: wsPublicURL, }) if err != nil { return err } return e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - URL: authenticatedBitfinexWebsocketEndpoint, - Authenticated: true, + Connector: e.wsConnect, + Authenticate: e.wsSendAuthConn, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generatePrivateSubscriptions, + SubscriptionsNotRequired: true, + Handler: e.wsHandleData, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + URL: wsAuthURL, + MessageFilter: wsAuthURL, }) } @@ -961,11 +974,6 @@ func (e *Exchange) GetOrderHistory(ctx context.Context, req *order.MultiOrderReq return req.Filter(e.Name, orders), nil } -// AuthenticateWebsocket sends an authentication message to the websocket -func (e *Exchange) AuthenticateWebsocket(ctx context.Context) error { - return e.WsSendAuth(ctx) -} - // appendOptionalDelimiter ensures that a delimiter is present for long character currencies func (e *Exchange) appendOptionalDelimiter(p *currency.Pair) { if (len(p.Base.String()) > 3 && !p.Quote.IsEmpty()) || diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index ecedcb93aa7..d82d284ae57 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -167,7 +167,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { } return e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - URL: e.Websocket.GetWebsocketURL(), + URL: wsURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 5d4a43440c6..54e582cb1af 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -3041,8 +3041,9 @@ func TestWSHandleData(t *testing.T) { keys := slices.Collect(maps.Keys(pushDataMap)) slices.Sort(keys) + conn := testexch.GetMockConn(t, e, "") for x := range keys { - err := e.wsHandleData(t.Context(), nil, asset.Spot, []byte(pushDataMap[keys[x]])) + err := e.wsHandleData(t.Context(), conn, asset.Spot, []byte(pushDataMap[keys[x]])) if keys[x] == "unhandled" { assert.ErrorIs(t, err, errUnhandledStreamData, "wsHandleData should error correctly for unhandled topics") } else { @@ -3254,12 +3255,10 @@ func TestWsTicker(t *testing.T) { asset.Spot, asset.Options, asset.USDTMarginedFutures, asset.USDTMarginedFutures, asset.USDCMarginedFutures, asset.USDCMarginedFutures, asset.CoinMarginedFutures, asset.CoinMarginedFutures, } - routingIndex := 0 + conn := testexch.GetMockConn(t, e, "") testexch.FixtureToDataHandler(t, "testdata/wsTicker.json", func(_ context.Context, r []byte) error { - require.Less(t, routingIndex, len(assetRouting), "routingIndex must stay within ticker fixture asset routing bounds") - a := assetRouting[routingIndex] - routingIndex++ - return e.wsHandleData(t.Context(), nil, a, r) + defer slices.Delete(assetRouting, 0, 1) + return e.wsHandleData(t.Context(), conn, assetRouting[0], r) }) e.Websocket.DataHandler.Close() expected := 8 @@ -3508,7 +3507,7 @@ func TestFetchTradablePairs(t *testing.T) { func TestDeltaUpdateOrderbook(t *testing.T) { t.Parallel() data := []byte(`{"topic":"orderbook.50.WEMIXUSDT","ts":1697573183768,"type":"snapshot","data":{"s":"WEMIXUSDT","b":[["0.9511","260.703"],["0.9677","0"]],"a":[],"u":3119516,"seq":14126848493},"cts":1728966699481}`) - err := e.wsHandleData(t.Context(), nil, asset.Spot, data) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), asset.Spot, data) require.NoError(t, err, "wsHandleData must not error") update := []byte(`{"topic":"orderbook.50.WEMIXUSDT","ts":1697573183768,"type":"delta","data":{"s":"WEMIXUSDT","b":[["0.9511","260.703"],["0.9677","0"]],"a":[],"u":3119516,"seq":14126848493},"cts":1728966699481}`) var wsResponse WebsocketResponse diff --git a/exchanges/bybit/bybit_websocket_requests.go b/exchanges/bybit/bybit_websocket_requests.go index be19d65ec0a..521fad5058a 100644 --- a/exchanges/bybit/bybit_websocket_requests.go +++ b/exchanges/bybit/bybit_websocket_requests.go @@ -8,15 +8,10 @@ import ( "github.com/gofrs/uuid" "github.com/thrasher-corp/gocryptotrader/encoding/json" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/request" ) -// Websocket request operation types -const ( - OutboundTradeConnection = "PRIVATE_TRADE" - InboundPrivateConnection = "PRIVATE" -) - // WSCreateOrder creates an order through the websocket connection func (e *Exchange) WSCreateOrder(ctx context.Context, r *PlaceOrderRequest) (*WebsocketOrderDetails, error) { if err := r.Validate(); err != nil { @@ -64,13 +59,21 @@ func (e *Exchange) WSCancelOrder(ctx context.Context, r *CancelOrderRequest) (*W // sendWebsocketTradeRequest sends a trade request to the exchange through the websocket connection func (e *Exchange) sendWebsocketTradeRequest(ctx context.Context, op, orderLinkID string, payload any, limit request.EndpointLimit) (*WebsocketOrderDetails, error) { + wsTradeURL, err := e.API.Endpoints.GetURL(exchange.WebsocketTrade) + if err != nil { + return nil, err + } + wsPrivateURL, err := e.API.Endpoints.GetURL(exchange.WebsocketPrivate) + if err != nil { + return nil, err + } // Get the outbound and inbound connections to send and receive the request. This makes sure both are live before // sending the request. - outbound, err := e.Websocket.GetConnection(OutboundTradeConnection) + outbound, err := e.Websocket.GetConnection(wsTradeURL) if err != nil { return nil, err } - inbound, err := e.Websocket.GetConnection(InboundPrivateConnection) + inbound, err := e.Websocket.GetConnection(wsPrivateURL) if err != nil { return nil, err } diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 9639d512b2b..070ac0f1eff 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -303,7 +303,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, conn websocket.Connection, resp []byte) error { return e.wsHandleData(ctx, conn, asset.USDTMarginedFutures, resp) }, - MessageFilter: asset.USDTMarginedFutures, // Unused but it allows us to differentiate between the two linear futures types. + MessageFilter: asset.USDTMarginedFutures, // Required to differentiate linear futures connections sharing the same endpoint URL. }); err != nil { return err } @@ -331,7 +331,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, conn websocket.Connection, resp []byte) error { return e.wsHandleData(ctx, conn, asset.USDCMarginedFutures, resp) }, - MessageFilter: asset.USDCMarginedFutures, // Unused but it allows us to differentiate between the two linear futures types. + MessageFilter: asset.USDCMarginedFutures, // Required to differentiate linear futures connections sharing the same endpoint URL. }); err != nil { return err } @@ -372,7 +372,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return e.wsHandleTradeData(conn, resp) }, Authenticate: e.WebsocketAuthenticateTradeConnection, - MessageFilter: OutboundTradeConnection, + MessageFilter: wsTradeURL, SubscriptionsNotRequired: true, }); err != nil { return err @@ -395,7 +395,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Unsubscriber: e.authUnsubscribe, Handler: e.wsHandleAuthenticatedData, Authenticate: e.WebsocketAuthenticatePrivateConnection, - MessageFilter: InboundPrivateConnection, + MessageFilter: wsPrivateURL, }) } diff --git a/exchanges/coinbase/coinbase_test.go b/exchanges/coinbase/coinbase_test.go index 46f0ec88a78..9c7dda347b0 100644 --- a/exchanges/coinbase/coinbase_test.go +++ b/exchanges/coinbase/coinbase_test.go @@ -2,24 +2,17 @@ package coinbase import ( "context" - "errors" "log" - "net/http" "os" "strconv" - "strings" "testing" "time" "github.com/gofrs/uuid" - gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" - "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" - "github.com/thrasher-corp/gocryptotrader/encoding/json" - "github.com/thrasher-corp/gocryptotrader/exchange/websocket" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/fundingrate" @@ -29,7 +22,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" - testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -103,18 +95,6 @@ func TestSetup(t *testing.T) { assert.ErrorIs(t, err, exchange.ErrSettingProxyAddress) } -func TestWsConnect(t *testing.T) { - sharedtestvalues.SkipTestIfCredentialsUnset(t, e) - exch := &Exchange{} - exch.Websocket = sharedtestvalues.NewTestWebsocket() - err := exch.WsConnect() - assert.ErrorIs(t, err, websocket.ErrWebsocketNotEnabled) - err = exchangeBaseHelper(exch) - require.NoError(t, err) - err = exch.Websocket.Enable(t.Context()) - assert.NoError(t, err) -} - func TestGetAccountByID(t *testing.T) { t.Parallel() _, err := e.GetAccountByID(t.Context(), "") @@ -1506,224 +1486,6 @@ func TestGetCurrencyTradeURL(t *testing.T) { } } -// TestWsAuth dials websocket, sends login request. -func TestWsAuth(t *testing.T) { - p := currency.Pairs{testPairFiat} - testexch.SkipTestIfCannotUseAuthenticatedWebsocket(t, e) - var dialer gws.Dialer - err := e.Websocket.Conn.Dial(t.Context(), &dialer, http.Header{}, nil) - require.NoError(t, err) - e.Websocket.Wg.Add(1) - go e.wsReadData(t.Context()) - err = e.Subscribe(subscription.List{ - { - Channel: "myAccount", - Asset: asset.All, - Pairs: p, - Authenticated: true, - }, - }) - assert.NoError(t, err) - timer := time.NewTimer(sharedtestvalues.WebsocketResponseDefaultTimeout) - select { - case badResponse := <-e.Websocket.DataHandler.C: - assert.IsType(t, []order.Detail{}, badResponse) - case <-timer.C: - } - timer.Stop() -} - -func TestWsHandleData(t *testing.T) { - done := make(chan struct{}) - t.Cleanup(func() { - close(done) - }) - go func() { - for { - select { - case <-e.Websocket.DataHandler.C: - continue - case <-done: - return - } - } - }() - _, err := e.wsHandleData(t.Context(), nil) - var syntaxErr *json.SyntaxError - assert.True(t, errors.As(err, &syntaxErr) || strings.Contains(err.Error(), "Syntax error no sources available, the input json is empty"), errJSONUnmarshalUnexpected) - mockJSON := []byte(`{"type": "error"}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.Error(t, err) - mockJSON = []byte(`{"sequence_num": 0, "channel": "subscriptions"}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.NoError(t, err) - var unmarshalTypeErr *json.UnmarshalTypeError - mockJSON = []byte(`{"sequence_num": 0, "channel": "status", "events": [{"type": 1234}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) - mockJSON = []byte(`{"sequence_num": 0, "channel": "status", "events": [{"type": "moo"}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.NoError(t, err) - mockJSON = []byte(`{"sequence_num": 0, "channel": "ticker", "events": [{"type": "moo", "tickers": false}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) - mockJSON = []byte(`{"sequence_num": 0, "channel": "candles", "events": [{"type": false}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) - mockJSON = []byte(`{"sequence_num": 0, "channel": "candles", "events": [{"type": "moo", "candles": [{"low": "1.1"}]}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.NoError(t, err) - mockJSON = []byte(`{"sequence_num": 0, "channel": "market_trades", "events": [{"type": false}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) - mockJSON = []byte(`{"sequence_num": 0, "channel": "market_trades", "events": [{"type": "moo", "trades": [{"price": "1.1"}]}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.NoError(t, err) - mockJSON = []byte(`{"sequence_num": 0, "channel": "l2_data", "events": [{"type": false, "updates": [{"price_level": "1.1"}]}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) - mockJSON = []byte(`{"sequence_num": 0, "channel": "l2_data", "timestamp": "2006-01-02T15:04:05Z", "events": [{"type": "moo", "updates": [{"price_level": "1.1"}]}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.ErrorIs(t, err, errUnknownL2DataType) - mockJSON = []byte(`{"sequence_num": 0, "channel": "l2_data", "timestamp": "2006-01-02T15:04:05Z", "events": [{"type": "snapshot", "product_id": "BTC-USD", "updates": [{"side": "bid", "price_level": "1.1", "new_quantity": "2.2"}]}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.NoError(t, err) - mockJSON = []byte(`{"sequence_num": 0, "channel": "l2_data", "timestamp": "2006-01-02T15:04:05Z", "events": [{"type": "update", "product_id": "BTC-USD", "updates": [{"side": "bid", "price_level": "1.1", "new_quantity": "2.2"}]}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.NoError(t, err) - mockJSON = []byte(`{"sequence_num": 0, "channel": "user", "events": [{"type": false}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.True(t, errors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) - mockJSON = []byte(`{"sequence_num": 0, "channel": "user", "events": [{"type": "l", "orders": [{"limit_price": "2.2", "total_fees": "1.1", "post_only": true}], "positions": {"perpetual_futures_positions": [{"margin_type": "fakeMarginType"}], "expiring_futures_positions": [{}]}}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.ErrorIs(t, err, order.ErrUnrecognisedOrderType) - mockJSON = []byte(`{"sequence_num": 0, "channel": "fakechan", "events": [{"type": ""}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.ErrorIs(t, err, errChannelNameUnknown) - p, err := e.FormatExchangeCurrency(currency.NewBTCUSD(), asset.Spot) - require.NoError(t, err) - e.pairAliases.Load(map[currency.Pair]currency.Pairs{ - p: {p}, - }) - mockJSON = []byte(`{"sequence_num": 0, "channel": "ticker", "events": [{"type": "moo", "tickers": [{"product_id": "BTC-USD", "price": "1.1"}]}]}`) - _, err = e.wsHandleData(t.Context(), mockJSON) - assert.NoError(t, err) -} - -func TestWsProcessCandleIntervalMapping(t *testing.T) { - t.Parallel() - ex := new(Exchange) - require.NoError(t, testexch.Setup(ex), "Setup instance must not error") - - resp := &StandardWebsocketResponse{ - Channel: "candles", - Events: json.RawMessage(`[{"type":"snapshot","candles":[{"start":1704067200,"low":"99.5","high":"101.0","open":"100.0","close":"100.5","volume":"12.3","product_id":"BTC-USD"}]}]`), - } - require.NoError(t, ex.wsProcessCandle(t.Context(), resp)) - - select { - case msg := <-ex.Websocket.DataHandler.C: - got, ok := msg.Data.([]kline.Item) - require.True(t, ok, "expected []kline.Item") - assert.Equal(t, []kline.Item{{ - Pair: currency.NewPairWithDelimiter("BTC", "USD", "-"), - Asset: asset.Spot, - Exchange: ex.Name, - Interval: kline.FiveMin, - Candles: []kline.Candle{{ - Time: time.Unix(1704067200, 0), - Open: 100, - Close: 100.5, - High: 101, - Low: 99.5, - Volume: 12.3, - }}, - }}, got) - default: - require.Fail(t, "expected websocket candle payload") - } -} - -func TestProcessSnapshotUpdate(t *testing.T) { - t.Parallel() - req := WebsocketOrderbookDataHolder{Changes: []WebsocketOrderbookData{{Side: "fakeside", PriceLevel: 1.1, NewQuantity: 2.2}}, ProductID: currency.NewBTCUSD()} - err := e.ProcessSnapshot(&req, time.Time{}) - assert.ErrorIs(t, err, order.ErrSideIsInvalid) - err = e.ProcessUpdate(&req, time.Time{}) - assert.ErrorIs(t, err, order.ErrSideIsInvalid) - req.Changes[0].Side = "offer" - err = e.ProcessSnapshot(&req, time.Now()) - assert.NoError(t, err) - err = e.ProcessUpdate(&req, time.Now()) - assert.NoError(t, err) -} - -func TestGenerateSubscriptions(t *testing.T) { - t.Parallel() - e := new(Exchange) - if err := testexch.Setup(e); err != nil { - log.Fatal(err) - } - e.Websocket.SetCanUseAuthenticatedEndpoints(true) - p1, err := e.GetEnabledPairs(asset.Spot) - require.NoError(t, err) - p2, err := e.GetEnabledPairs(asset.Futures) - require.NoError(t, err) - exp := subscription.List{} - for _, baseSub := range defaultSubscriptions.Enabled() { - s := baseSub.Clone() - s.QualifiedChannel = subscriptionNames[s.Channel] - switch s.Asset { - case asset.Spot: - s.Pairs = p1 - case asset.Futures: - s.Pairs = p2 - case asset.All: - s2 := s.Clone() - s2.Asset = asset.Futures - s2.Pairs = p2 - exp = append(exp, s2) - s.Asset = asset.Spot - s.Pairs = p1 - } - exp = append(exp, s) - } - subs, err := e.generateSubscriptions() - require.NoError(t, err) - testsubs.EqualLists(t, exp, subs) - _, err = subscription.List{{Channel: "wibble"}}.ExpandTemplates(e) - assert.ErrorContains(t, err, "subscription channel not supported: wibble") -} - -func TestSubscribeUnsubscribe(t *testing.T) { - t.Parallel() - sharedtestvalues.SkipTestIfCredentialsUnset(t, e) - req := subscription.List{{Channel: "heartbeat", Asset: asset.Spot, Pairs: currency.Pairs{currency.NewPairWithDelimiter(testCrypto.String(), testFiat.String(), "-")}}} - err := e.Subscribe(req) - assert.NoError(t, err) - err = e.Unsubscribe(req) - assert.NoError(t, err) -} - -func TestCheckSubscriptions(t *testing.T) { - t.Parallel() - e := &Exchange{ - Base: exchange.Base{ - Config: &config.Exchange{ - Features: &config.FeaturesConfig{ - Subscriptions: subscription.List{ - {Enabled: true, Channel: "matches"}, - }, - }, - }, - Features: exchange.Features{}, - }, - } - e.checkSubscriptions() - testsubs.EqualLists(t, defaultSubscriptions.Enabled(), e.Features.Subscriptions) - testsubs.EqualLists(t, defaultSubscriptions, e.Config.Features.Subscriptions) -} - func TestGetJWT(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, e) diff --git a/exchanges/coinbase/coinbase_types.go b/exchanges/coinbase/coinbase_types.go index 8d741c12df8..65a366d2e43 100644 --- a/exchanges/coinbase/coinbase_types.go +++ b/exchanges/coinbase/coinbase_types.go @@ -8,6 +8,7 @@ import ( "github.com/gofrs/uuid" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" + "github.com/thrasher-corp/gocryptotrader/exchange/websocket" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/types" @@ -29,6 +30,8 @@ type Exchange struct { exchange.Base jwt jwtManager pairAliases pairAliases + wsSeqState map[websocket.Connection]uint64 + wsSeqMu sync.Mutex } // Version is used for the niche cases where the Version of the API must be specified and passed around for proper functionality diff --git a/exchanges/coinbase/coinbase_websocket.go b/exchanges/coinbase/coinbase_websocket.go index 24b6a6914a6..96b49ee0010 100644 --- a/exchanges/coinbase/coinbase_websocket.go +++ b/exchanges/coinbase/coinbase_websocket.go @@ -22,7 +22,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" - "github.com/thrasher-corp/gocryptotrader/log" ) const ( @@ -57,47 +56,8 @@ var defaultSubscriptions = subscription.List{ */ } -// WsConnect initiates a websocket connection -func (e *Exchange) WsConnect() error { - ctx := context.TODO() - if !e.Websocket.IsEnabled() || !e.IsEnabled() { - return websocket.ErrWebsocketNotEnabled - } - var dialer gws.Dialer - if err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}, nil); err != nil { - return err - } - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx) - return nil -} - -// wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData(ctx context.Context) { - defer e.Websocket.Wg.Done() - var seqCount uint64 - for { - resp := e.Websocket.Conn.ReadMessage() - if resp.Raw == nil { - return - } - sequence, err := e.wsHandleData(ctx, resp.Raw) - if err != nil { - if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { - log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) - } - } - if sequence != nil { - if *sequence != seqCount { - err := fmt.Errorf("%w: received %v, expected %v", errOutOfSequence, sequence, seqCount) - if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { - log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) - } - seqCount = *sequence - } - seqCount++ - } - } +func (e *Exchange) wsConnect(ctx context.Context, conn websocket.Connection) error { + return conn.Dial(ctx, &gws.Dialer{}, http.Header{}, nil) } // wsProcessTicker handles ticker data from the websocket @@ -307,47 +267,58 @@ func (e *Exchange) wsProcessUser(ctx context.Context, resp *StandardWebsocketRes } // wsHandleData handles all the websocket data coming from the websocket connection -func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) (*uint64, error) { +func (e *Exchange) wsHandleData(ctx context.Context, conn websocket.Connection, respRaw []byte) error { var resp StandardWebsocketResponse if err := json.Unmarshal(respRaw, &resp); err != nil { - return nil, err + return err + } + if err := e.checkWSSequence(conn, resp.Sequence); err != nil { + return err } if resp.Error != "" { - return &resp.Sequence, errors.New(resp.Error) + return errors.New(resp.Error) } switch resp.Channel { case "subscriptions", "heartbeats": - return &resp.Sequence, nil + return nil case "status": var wsStatus []WebsocketProductHolder if err := json.Unmarshal(resp.Events, &wsStatus); err != nil { - return &resp.Sequence, err + return err } - return &resp.Sequence, e.Websocket.DataHandler.Send(ctx, wsStatus) + return e.Websocket.DataHandler.Send(ctx, wsStatus) case "ticker", "ticker_batch": - if err := e.wsProcessTicker(ctx, &resp); err != nil { - return &resp.Sequence, err - } + return e.wsProcessTicker(ctx, &resp) case "candles": - if err := e.wsProcessCandle(ctx, &resp); err != nil { - return &resp.Sequence, err - } + return e.wsProcessCandle(ctx, &resp) case "market_trades": - if err := e.wsProcessMarketTrades(ctx, &resp); err != nil { - return &resp.Sequence, err - } + return e.wsProcessMarketTrades(ctx, &resp) case "l2_data": - if err := e.wsProcessL2(&resp); err != nil { - return &resp.Sequence, err - } + return e.wsProcessL2(&resp) case "user": - if err := e.wsProcessUser(ctx, &resp); err != nil { - return &resp.Sequence, err - } + return e.wsProcessUser(ctx, &resp) default: - return &resp.Sequence, errChannelNameUnknown + return errChannelNameUnknown } - return &resp.Sequence, nil +} + +func (e *Exchange) checkWSSequence(conn websocket.Connection, sequence uint64) error { + e.wsSeqMu.Lock() + defer e.wsSeqMu.Unlock() + if e.wsSeqState == nil { + e.wsSeqState = make(map[websocket.Connection]uint64) + } + expected, ok := e.wsSeqState[conn] + if !ok { + e.wsSeqState[conn] = sequence + 1 + return nil + } + if sequence != expected { + e.wsSeqState[conn] = sequence + 1 + return fmt.Errorf("%w: received %v, expected %v", errOutOfSequence, sequence, expected) + } + e.wsSeqState[conn] = expected + 1 + return nil } // ProcessSnapshot processes the initial orderbook snap shot @@ -418,18 +389,8 @@ func (e *Exchange) GetSubscriptionTemplate(_ *subscription.Subscription) (*templ return template.New("master.tmpl").Funcs(template.FuncMap{"channelName": channelName}).Parse(subTplText) } -// Subscribe sends a websocket message to receive data from a list of channels -func (e *Exchange) Subscribe(subs subscription.List) error { - return e.ParallelChanOp(context.TODO(), subs, func(ctx context.Context, subs subscription.List) error { return e.manageSubs(ctx, "subscribe", subs) }, 1) -} - -// Unsubscribe sends a websocket message to stop receiving data from a list of channels -func (e *Exchange) Unsubscribe(subs subscription.List) error { - return e.ParallelChanOp(context.TODO(), subs, func(ctx context.Context, subs subscription.List) error { return e.manageSubs(ctx, "unsubscribe", subs) }, 1) -} - // manageSubs subscribes or unsubscribes from a list of websocket channels -func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription.List) error { +func (e *Exchange) manageSubs(ctx context.Context, conn websocket.Connection, op string, subs subscription.List) error { var errs error subs, errs = subs.ExpandTemplates(e) for _, s := range subs { @@ -447,12 +408,12 @@ func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription. return err } } - if err = e.Websocket.Conn.SendJSONMessage(ctx, limitType, r); err == nil { + if err = conn.SendJSONMessage(ctx, limitType, r); err == nil { switch op { case "subscribe": - err = e.Websocket.AddSuccessfulSubscriptions(e.Websocket.Conn, s) + err = e.Websocket.AddSuccessfulSubscriptions(conn, s) case "unsubscribe": - err = e.Websocket.RemoveSubscriptions(e.Websocket.Conn, s) + err = e.Websocket.RemoveSubscriptions(conn, s) } } errs = common.AppendError(errs, err) @@ -460,6 +421,14 @@ func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription. return errs } +func (e *Exchange) subscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return e.manageSubs(ctx, conn, "subscribe", subs) +} + +func (e *Exchange) unsubscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return e.manageSubs(ctx, conn, "unsubscribe", subs) +} + // GetWSJWT returns a JWT, using a stored one of it's provided, and generating a new one otherwise func (e *Exchange) GetWSJWT(ctx context.Context) (string, error) { e.jwt.m.RLock() diff --git a/exchanges/coinbase/coinbase_websocket_test.go b/exchanges/coinbase/coinbase_websocket_test.go new file mode 100644 index 00000000000..38cb7b909ec --- /dev/null +++ b/exchanges/coinbase/coinbase_websocket_test.go @@ -0,0 +1,587 @@ +package coinbase + +import ( + "context" + stderrors "errors" + "log" + "strconv" + "strings" + "testing" + "text/template" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/currency" + gctjson "github.com/thrasher-corp/gocryptotrader/encoding/json" + "github.com/thrasher-corp/gocryptotrader/exchange/websocket" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/kline" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" + testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" +) + +func TestWsConnect(t *testing.T) { + t.Parallel() + exch := &Exchange{} + exch.Websocket = sharedtestvalues.NewTestWebsocket() + err := exch.Websocket.Connect(t.Context()) + assert.ErrorIs(t, err, websocket.ErrWebsocketNotEnabled) + err = exchangeBaseHelper(exch) + require.NoError(t, err) + err = exch.Websocket.Enable(t.Context()) + assert.NoError(t, err) +} + +func TestWsHandleData(t *testing.T) { + t.Parallel() + t.Run("nil message", func(t *testing.T) { + t.Parallel() + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), nil) + var syntaxErr *gctjson.SyntaxError + assert.True(t, stderrors.As(err, &syntaxErr) || strings.Contains(err.Error(), "Syntax error no sources available, the input json is empty"), errJSONUnmarshalUnexpected) + }) + + t.Run("error type message", func(t *testing.T) { + t.Parallel() + mockJSON := []byte(`{"type": "error"}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.Error(t, err) + }) + + t.Run("subscriptions channel", func(t *testing.T) { + t.Parallel() + + mockJSON := []byte(`{"sequence_num": 0, "channel": "subscriptions"}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.NoError(t, err) + }) + + t.Run("heartbeats channel", func(t *testing.T) { + t.Parallel() + mockJSON := []byte(`{"sequence_num": 0, "channel": "heartbeats"}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.NoError(t, err) + }) + + t.Run("status channel success", func(t *testing.T) { + t.Parallel() + mockJSON := []byte(`{"sequence_num": 0, "channel": "status", "events": [{"type": "status", "products": []}]}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.NoError(t, err) + }) + + t.Run("status events type unmarshal", func(t *testing.T) { + t.Parallel() + var unmarshalTypeErr *gctjson.UnmarshalTypeError + mockJSON := []byte(`{"sequence_num": 0, "channel": "status", "events": [{"type": 1234}]}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.True(t, stderrors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) + }) + + t.Run("ticker tickers unmarshal", func(t *testing.T) { + t.Parallel() + var unmarshalTypeErr *gctjson.UnmarshalTypeError + mockJSON := []byte(`{"sequence_num": 0, "channel": "ticker", "events": [{"type": "moo", "tickers": false}]}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.True(t, stderrors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) + }) + + t.Run("candles events type unmarshal", func(t *testing.T) { + t.Parallel() + var unmarshalTypeErr *gctjson.UnmarshalTypeError + mockJSON := []byte(`{"sequence_num": 0, "channel": "candles", "events": [{"type": false}]}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.True(t, stderrors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) + }) + + t.Run("market_trades events type unmarshal", func(t *testing.T) { + t.Parallel() + var unmarshalTypeErr *gctjson.UnmarshalTypeError + mockJSON := []byte(`{"sequence_num": 0, "channel": "market_trades", "events": [{"type": false}]}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.True(t, stderrors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) + }) + + t.Run("l2_data updates unmarshal", func(t *testing.T) { + t.Parallel() + var unmarshalTypeErr *gctjson.UnmarshalTypeError + mockJSON := []byte(`{"sequence_num": 0, "channel": "l2_data", "events": [{"type": false, "updates": [{"price_level": "1.1"}]}]}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.True(t, stderrors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) + }) + + t.Run("user events type unmarshal", func(t *testing.T) { + t.Parallel() + var unmarshalTypeErr *gctjson.UnmarshalTypeError + mockJSON := []byte(`{"sequence_num": 0, "channel": "user", "events": [{"type": false}]}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.True(t, stderrors.As(err, &unmarshalTypeErr) || strings.Contains(err.Error(), "mismatched type with value"), errJSONUnmarshalUnexpected) + }) + + t.Run("unknown channel", func(t *testing.T) { + t.Parallel() + mockJSON := []byte(`{"sequence_num": 0, "channel": "fakechan", "events": [{"type": ""}]}`) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), mockJSON) + assert.ErrorIs(t, err, errChannelNameUnknown) + }) + + t.Run("sequence validation before payload error", func(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex)) + conn := testexch.GetMockConn(t, ex, "ws://coinbase-wshandledata-seq") + assert.NoError(t, ex.wsHandleData(t.Context(), conn, []byte(`{"sequence_num": 1, "channel": "subscriptions"}`))) + err := ex.wsHandleData(t.Context(), conn, []byte(`{"sequence_num": 3, "channel": "subscriptions", "type": "error"}`)) + assert.ErrorIs(t, err, errOutOfSequence) + }) + + t.Run("ticker with alias loaded", func(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex)) + p, err := ex.FormatExchangeCurrency(currency.NewBTCUSD(), asset.Spot) + require.NoError(t, err) + ex.pairAliases.Load(map[currency.Pair]currency.Pairs{p: {p}}) + mockJSON := []byte(`{"sequence_num": 0, "channel": "ticker", "events": [{"type": "moo", "tickers": [{"product_id": "BTC-USD", "price": "1.1"}]}]}`) + err = ex.wsHandleData(t.Context(), testexch.GetMockConn(t, ex, ""), mockJSON) + assert.NoError(t, err) + }) +} + +func TestWsHandleDataSequence(t *testing.T) { + t.Parallel() + connA := testexch.GetMockConn(t, e, "ws://coinbase-seq-a") + connB := testexch.GetMockConn(t, e, "ws://coinbase-seq-b") + buildSubMsg := func(seq uint64) []byte { + return []byte(`{"sequence_num":` + strconv.FormatUint(seq, 10) + `,"channel":"subscriptions"}`) + } + + assert.NoError(t, e.wsHandleData(t.Context(), connA, buildSubMsg(7)), "wsHandleData should not error for initial sequence") + assert.NoError(t, e.wsHandleData(t.Context(), connA, buildSubMsg(8)), "wsHandleData should not error for in-order sequence") + assert.ErrorIs(t, e.wsHandleData(t.Context(), connA, buildSubMsg(10)), errOutOfSequence, "wsHandleData should error for out-of-order sequence") + assert.NoError(t, e.wsHandleData(t.Context(), connA, buildSubMsg(11)), "wsHandleData should not error after sequence state is resynced") + assert.NoError(t, e.wsHandleData(t.Context(), connB, buildSubMsg(3)), "wsHandleData should not error for a different connection sequence state") +} + +func TestProcessSnapshotUpdate(t *testing.T) { + t.Parallel() + req := WebsocketOrderbookDataHolder{Changes: []WebsocketOrderbookData{{Side: "fakeside", PriceLevel: 1.1, NewQuantity: 2.2}}, ProductID: currency.NewBTCUSD()} + err := e.ProcessSnapshot(&req, time.Time{}) + assert.ErrorIs(t, err, order.ErrSideIsInvalid) + err = e.ProcessUpdate(&req, time.Time{}) + assert.ErrorIs(t, err, order.ErrSideIsInvalid) + req.Changes[0].Side = "offer" + err = e.ProcessSnapshot(&req, time.Now()) + assert.NoError(t, err) + err = e.ProcessUpdate(&req, time.Now()) + assert.NoError(t, err) +} + +func TestGenerateSubscriptions(t *testing.T) { + t.Parallel() + e := new(Exchange) + if err := testexch.Setup(e); err != nil { + log.Fatal(err) + } + e.Websocket.SetCanUseAuthenticatedEndpoints(true) + p1, err := e.GetEnabledPairs(asset.Spot) + require.NoError(t, err) + p2, err := e.GetEnabledPairs(asset.Futures) + require.NoError(t, err) + exp := subscription.List{} + for _, baseSub := range defaultSubscriptions.Enabled() { + s := baseSub.Clone() + s.QualifiedChannel = subscriptionNames[s.Channel] + switch s.Asset { + case asset.Spot: + s.Pairs = p1 + case asset.Futures: + s.Pairs = p2 + case asset.All: + s2 := s.Clone() + s2.Asset = asset.Futures + s2.Pairs = p2 + exp = append(exp, s2) + s.Asset = asset.Spot + s.Pairs = p1 + } + exp = append(exp, s) + } + subs, err := e.generateSubscriptions() + require.NoError(t, err) + testsubs.EqualLists(t, exp, subs) + _, err = subscription.List{{Channel: "wibble"}}.ExpandTemplates(e) + assert.ErrorContains(t, err, "subscription channel not supported: wibble") +} + +func TestSubscribeUnsubscribe(t *testing.T) { + t.Parallel() + sharedtestvalues.SkipTestIfCredentialsUnset(t, e) + req := subscription.List{{Channel: "heartbeat", Asset: asset.Spot, Pairs: currency.Pairs{currency.NewPairWithDelimiter(testCrypto.String(), testFiat.String(), "-")}}} + err := subscribeForTest(t.Context(), e, req) + assert.NoError(t, err) + err = unsubscribeForTest(t.Context(), e, req) + assert.NoError(t, err) +} + +func TestCheckSubscriptions(t *testing.T) { + t.Parallel() + e := &Exchange{ + Base: exchange.Base{ + Config: &config.Exchange{ + Features: &config.FeaturesConfig{ + Subscriptions: subscription.List{ + {Enabled: true, Channel: "matches"}, + }, + }, + }, + Features: exchange.Features{}, + }, + } + e.checkSubscriptions() + testsubs.EqualLists(t, defaultSubscriptions.Enabled(), e.Features.Subscriptions) + testsubs.EqualLists(t, defaultSubscriptions, e.Config.Features.Subscriptions) +} + +func TestCheckWSSequenceAdditionalCoverage(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex)) + assert.NoError(t, ex.checkWSSequence(nil, 1)) + conn := testexch.GetMockConn(t, ex, "ws://coinbase-seq") + // first sequence seen sets expected+1 + assert.NoError(t, ex.checkWSSequence(conn, 7)) + // in-order + assert.NoError(t, ex.checkWSSequence(conn, 8)) + // out-of-order resets expected and returns err + err := ex.checkWSSequence(conn, 10) + assert.ErrorIs(t, err, errOutOfSequence) + // resumed should now accept 11 + assert.NoError(t, ex.checkWSSequence(conn, 11)) +} + +func TestGetSubscriptionTemplate(t *testing.T) { + t.Parallel() + ex := new(Exchange) + tpl, err := ex.GetSubscriptionTemplate(nil) + require.NoError(t, err) + require.NotNil(t, tpl) + _, err = template.Must(tpl, nil).Parse("{{ channelName . }}") + assert.NoError(t, err) +} + +func TestManageSubsNilConn(t *testing.T) { + t.Parallel() + ex := new(Exchange) + err := ex.manageSubs(t.Context(), nil, "subscribe", subscription.List{}) + assert.ErrorIs(t, err, websocket.ErrNotConnected) +} + +func TestSubscribeUnsubscribeForConnectionNilConn(t *testing.T) { + t.Parallel() + ex := new(Exchange) + err := ex.subscribeForConnection(t.Context(), nil, subscription.List{}) + assert.ErrorIs(t, err, websocket.ErrNotConnected) + err = ex.unsubscribeForConnection(t.Context(), nil, subscription.List{}) + assert.ErrorIs(t, err, websocket.ErrNotConnected) +} + +func TestGetWSJWTCacheAndRefresh(t *testing.T) { + t.Parallel() + ex := new(Exchange) + // cached token path + ex.jwt.token = "cached" + ex.jwt.expiresAt = time.Now().Add(time.Hour) + tok, err := ex.GetWSJWT(t.Context()) + require.NoError(t, err) + assert.Equal(t, "cached", tok) + + // expired path uses GetJWT; without creds we just assert it returns an error + ex.jwt.expiresAt = time.Now().Add(-time.Second) + _, err = ex.GetWSJWT(t.Context()) + assert.Error(t, err) +} + +func TestProcessBidAskArray(t *testing.T) { + t.Parallel() + snap := &WebsocketOrderbookDataHolder{Changes: []WebsocketOrderbookData{{Side: "bid", PriceLevel: 1.1, NewQuantity: 2.2}, {Side: "offer", PriceLevel: 1.2, NewQuantity: 3.3}}} + bids, asks, err := processBidAskArray(snap, true) + require.NoError(t, err) + assert.Len(t, bids, 1) + assert.Len(t, asks, 1) + + upd := &WebsocketOrderbookDataHolder{Changes: []WebsocketOrderbookData{{Side: "bid", PriceLevel: 1.1, NewQuantity: 2.2}}} + bids, asks, err = processBidAskArray(upd, false) + require.NoError(t, err) + assert.Len(t, bids, 1) + assert.Empty(t, asks) + + bad := &WebsocketOrderbookDataHolder{Changes: []WebsocketOrderbookData{{Side: "wat", PriceLevel: 1.1, NewQuantity: 2.2}}} + _, _, err = processBidAskArray(bad, false) + assert.ErrorIs(t, err, order.ErrSideIsInvalid) +} + +func TestStatusToStandardStatusWebsocket(t *testing.T) { + t.Parallel() + st, err := statusToStandardStatus("PENDING") + require.NoError(t, err) + assert.Equal(t, order.New, st) + _, err = statusToStandardStatus("unknown") + assert.ErrorIs(t, err, order.ErrUnsupportedStatusType) +} + +func TestStringToStandardTypeWebsocket(t *testing.T) { + t.Parallel() + tp, err := stringToStandardType("LIMIT_ORDER_TYPE") + require.NoError(t, err) + assert.Equal(t, order.Limit, tp) + _, err = stringToStandardType("wat") + assert.ErrorIs(t, err, order.ErrUnrecognisedOrderType) +} + +func TestStringToStandardAssetWebsocket(t *testing.T) { + t.Parallel() + at, err := stringToStandardAsset("SPOT") + require.NoError(t, err) + assert.Equal(t, asset.Spot, at) + _, err = stringToStandardAsset("wat") + assert.ErrorIs(t, err, asset.ErrNotSupported) +} + +func TestStrategyDecoderWebsocket(t *testing.T) { + t.Parallel() + tif, err := strategyDecoder("IMMEDIATE_OR_CANCEL") + require.NoError(t, err) + assert.True(t, tif.Is(order.ImmediateOrCancel)) + _, err = strategyDecoder("wat") + assert.ErrorIs(t, err, errUnrecognisedStrategyType) +} + +func TestChannelNameWebsocket(t *testing.T) { + t.Parallel() + name, err := channelName(&subscription.Subscription{Channel: subscription.HeartbeatChannel}) + require.NoError(t, err) + assert.Equal(t, "heartbeats", name) + _, err = channelName(&subscription.Subscription{Channel: "wat"}) + assert.ErrorIs(t, err, subscription.ErrNotSupported) +} + +func TestProcessSnapshotUpdateSendsToOrderbook(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex)) + pair := currency.NewBTCUSD() + require.NoError(t, ex.CurrencyPairs.StorePairs(asset.Spot, currency.Pairs{pair}, true)) + ex.pairAliases.Load(map[currency.Pair]currency.Pairs{pair: {pair}}) + snap := WebsocketOrderbookDataHolder{ProductID: pair, Changes: []WebsocketOrderbookData{{Side: "bid", PriceLevel: 1.1, NewQuantity: 2.2}}} + err := ex.ProcessSnapshot(&snap, time.Now()) + assert.NoError(t, err) + upd := WebsocketOrderbookDataHolder{ProductID: pair, Changes: []WebsocketOrderbookData{{Side: "bid", PriceLevel: 1.2, NewQuantity: 1.1}}} + err = ex.ProcessUpdate(&upd, time.Now()) + assert.NoError(t, err) +} + +func receiveDataHandlerPayload(t *testing.T, ex *Exchange) any { + t.Helper() + select { + case payload := <-ex.Websocket.DataHandler.C: + return payload.Data + case <-time.After(time.Second): + t.Fatal("timed out waiting for websocket data handler payload") + return nil + } +} + +func TestWSProcessCandle(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex)) + + resp := &StandardWebsocketResponse{ + Timestamp: time.Unix(1704067200, 0), + Events: []byte(`[{ + "type":"update", + "candles":[{ + "start":"1704067200", + "low":"1", + "high":"2", + "open":"1.25", + "close":"1.75", + "volume":"3.5", + "product_id":"BTC-USD" + }] + }]`), + } + require.NoError(t, ex.wsProcessCandle(t.Context(), resp)) + + data := receiveDataHandlerPayload(t, ex) + candles, ok := data.([]kline.Item) + require.True(t, ok) + require.Len(t, candles, 1) + assert.Equal(t, currency.NewPairWithDelimiter("BTC", "USD", "-"), candles[0].Pair) + assert.Equal(t, asset.Spot, candles[0].Asset) + + resp.Events = []byte(`[{"type":false}]`) + assert.Error(t, ex.wsProcessCandle(t.Context(), resp)) +} + +func TestWSProcessMarketTrades(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex)) + + resp := &StandardWebsocketResponse{ + Events: []byte(`[{ + "type":"update", + "trades":[{ + "trade_id":"123", + "product_id":"BTC-USD", + "price":"101.2", + "size":"0.5", + "side":"BUY", + "time":"2024-01-01T00:00:00Z" + }] + }]`), + } + require.NoError(t, ex.wsProcessMarketTrades(t.Context(), resp)) + + data := receiveDataHandlerPayload(t, ex) + trades, ok := data.([]trade.Data) + require.True(t, ok) + require.Len(t, trades, 1) + assert.Equal(t, currency.NewPairWithDelimiter("BTC", "USD", "-"), trades[0].CurrencyPair) + assert.Equal(t, order.Buy, trades[0].Side) + + resp.Events = []byte(`[{"type":false}]`) + assert.Error(t, ex.wsProcessMarketTrades(t.Context(), resp)) +} + +func TestWSProcessL2(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex)) + + exchangePair := currency.NewPairWithDelimiter("BTC", "USD", "-") + aliasPair := currency.NewBTCUSD() + require.NoError(t, ex.CurrencyPairs.StorePairs(asset.Spot, currency.Pairs{aliasPair}, true)) + ex.pairAliases.Load(map[currency.Pair]currency.Pairs{exchangePair: {aliasPair}}) + + resp := &StandardWebsocketResponse{ + Timestamp: time.Now(), + Events: []byte(`[{ + "type":"snapshot", + "product_id":"BTC-USD", + "updates":[ + {"side":"bid","price_level":"1.1","new_quantity":"2.2"}, + {"side":"offer","price_level":"1.2","new_quantity":"2.3"} + ] + },{ + "type":"update", + "product_id":"BTC-USD", + "updates":[ + {"side":"bid","price_level":"1.15","new_quantity":"1.9"} + ] + }]`), + } + require.NoError(t, ex.wsProcessL2(resp)) + _, err := ex.Websocket.Orderbook.GetOrderbook(aliasPair, asset.Spot) + assert.NoError(t, err) + + resp.Events = []byte(`[{"type":"wat","product_id":"BTC-USD","updates":[]}]`) + assert.ErrorIs(t, ex.wsProcessL2(resp), errUnknownL2DataType) +} + +func TestWSProcessUser(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex)) + + resp := &StandardWebsocketResponse{ + Events: []byte(`[{ + "type":"snapshot", + "orders":[{ + "order_type":"LIMIT_ORDER_TYPE", + "order_side":"BUY", + "status":"OPEN", + "avg_price":"100", + "limit_price":"101", + "client_order_id":"cid", + "cumulative_quantity":"0.25", + "leaves_quantity":"0.75", + "order_id":"oid", + "product_id":"BTC-USD", + "product_type":"SPOT", + "stop_price":"0", + "time_in_force":"GOOD_UNTIL_CANCELLED", + "total_fees":"0.1", + "creation_time":"2024-01-01T00:00:00Z", + "end_time":"2024-01-01T01:00:00Z", + "post_only":true + }], + "positions":{ + "perpetual_futures_positions":[{ + "product_id":"BTC-USD", + "position_side":"LONG", + "margin_type":"cross", + "net_size":"1", + "leverage":"2" + }], + "expiring_futures_positions":[{ + "product_id":"BTC-USD", + "side":"SHORT", + "number_of_contracts":"3", + "entry_price":"99" + }] + } + }]`), + } + require.NoError(t, ex.wsProcessUser(t.Context(), resp)) + + data := receiveDataHandlerPayload(t, ex) + orders, ok := data.([]order.Detail) + require.True(t, ok) + require.Len(t, orders, 3) + assert.True(t, orders[0].TimeInForce.Is(order.GoodTillCancel)) + assert.True(t, orders[0].TimeInForce.Is(order.PostOnly)) + assert.Equal(t, asset.Futures, orders[1].AssetType) + + resp.Events = []byte(`[{"type":"snapshot","orders":[{"order_type":"WAT"}]}]`) + assert.ErrorIs(t, ex.wsProcessUser(t.Context(), resp), order.ErrUnrecognisedOrderType) +} + +func subscribeForTest(ctx context.Context, e *Exchange, subs subscription.List) error { + wsRunningURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + if err != nil { + return err + } + conn, err := e.Websocket.GetConnection(wsRunningURL) + if err != nil { + conn, err = e.Websocket.GetConnection(coinbaseWebsocketURL) + if err != nil { + return err + } + } + return e.subscribeForConnection(ctx, conn, subs) +} + +func unsubscribeForTest(ctx context.Context, e *Exchange, subs subscription.List) error { + wsRunningURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + if err != nil { + return err + } + conn, err := e.Websocket.GetConnection(wsRunningURL) + if err != nil { + conn, err = e.Websocket.GetConnection(coinbaseWebsocketURL) + if err != nil { + return err + } + } + return e.unsubscribeForConnection(ctx, conn, subs) +} diff --git a/exchanges/coinbase/coinbase_wrapper.go b/exchanges/coinbase/coinbase_wrapper.go index 762664cce1c..d3ab8a10f21 100644 --- a/exchanges/coinbase/coinbase_wrapper.go +++ b/exchanges/coinbase/coinbase_wrapper.go @@ -145,14 +145,9 @@ func (e *Exchange) Setup(exch *config.Exchange) error { } if err := e.Websocket.Setup(&websocket.ManagerSetup{ - ExchangeConfig: exch, - DefaultURL: coinbaseWebsocketURL, - RunningURL: wsRunningURL, - Connector: e.WsConnect, - Subscriber: e.Subscribe, - Unsubscriber: e.Unsubscribe, - GenerateSubscriptions: e.generateSubscriptions, - Features: &e.Features.Supports.WebsocketCapabilities, + ExchangeConfig: exch, + UseMultiConnectionManagement: true, + Features: &e.Features.Supports.WebsocketCapabilities, OrderbookBufferConfig: buffer.Config{ SortBuffer: true, }, @@ -161,8 +156,15 @@ func (e *Exchange) Setup(exch *config.Exchange) error { } return e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + URL: wsRunningURL, + Connector: e.wsConnect, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generateSubscriptions, + Handler: e.wsHandleData, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + MessageFilter: wsRunningURL, }) } diff --git a/exchanges/deribit/deribit_test.go b/exchanges/deribit/deribit_test.go index e71bba2126c..9324d43bd11 100644 --- a/exchanges/deribit/deribit_test.go +++ b/exchanges/deribit/deribit_test.go @@ -129,8 +129,8 @@ func TestUpdateTicker(t *testing.T) { for assetType, cp := range assetTypeToPairsMap { result, err := e.UpdateTicker(t.Context(), cp, assetType) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) } } @@ -138,7 +138,7 @@ func TestUpdateOrderbook(t *testing.T) { t.Parallel() for assetType, cp := range assetTypeToPairsMap { result, err := e.UpdateOrderbook(t.Context(), cp, assetType) - require.NoErrorf(t, err, "asset type: %v", assetType) + require.NoErrorf(t, err, "request must not error for asset type %v", assetType) require.NotNil(t, result) } } @@ -149,7 +149,7 @@ func TestGetHistoricTrades(t *testing.T) { require.ErrorIs(t, err, asset.ErrNotSupported) for assetType, cp := range map[asset.Item]currency.Pair{asset.Spot: spotTradablePair, asset.Futures: futuresTradablePair} { _, err = e.GetHistoricTrades(t.Context(), cp, assetType, time.Now().Add(-time.Minute*10), time.Now()) - require.NoErrorf(t, err, "asset type: %v", assetType) + require.NoErrorf(t, err, "request must not error for asset type %v", assetType) } } @@ -157,8 +157,8 @@ func TestFetchRecentTrades(t *testing.T) { t.Parallel() for assetType, cp := range assetTypeToPairsMap { result, err := e.GetRecentTrades(t.Context(), cp, assetType) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) } } @@ -221,8 +221,8 @@ func TestSubmitOrder(t *testing.T) { var info *InstrumentData for assetType, cp := range assetToPairStringMap { info, err = e.GetInstrument(t.Context(), formatPairString(assetType, cp)) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) result, err = e.SubmitOrder( t.Context(), @@ -236,8 +236,8 @@ func TestSubmitOrder(t *testing.T) { Pair: cp, }, ) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) } } @@ -259,8 +259,8 @@ func TestGetMarkPriceHistory(t *testing.T) { futureComboPairToString(futureComboTradablePair), } { result, err = e.GetMarkPriceHistory(t.Context(), ps, time.Now().Add(-5*time.Minute), time.Now()) - require.NoErrorf(t, err, "expected nil, got %v for pair %s", err, ps) - require.NotNilf(t, result, "expected result not to be nil for pair %s", ps) + require.NoErrorf(t, err, "request must not error for pair %s", ps) + require.NotNilf(t, result, "result must not be nil for pair %s", ps) } } @@ -277,8 +277,8 @@ func TestWSRetrieveMarkPriceHistory(t *testing.T) { futureComboPairToString(futureComboTradablePair), } { result, err = e.WSRetrieveMarkPriceHistory(t.Context(), ps, time.Now().Add(-4*time.Hour), time.Now()) - require.NoErrorf(t, err, "expected %v, got %v currency pair %v", nil, err, ps) - require.NotNilf(t, result, "expected value not to be nil for pair: %v", ps) + require.NoErrorf(t, err, "request must not error for currency pair %v", ps) + require.NotNilf(t, result, "result must not be nil for pair %v", ps) } } @@ -341,8 +341,8 @@ func TestWSRetrieveBookSummaryByInstrument(t *testing.T) { optionComboPairToString(optionComboTradablePair), } { result, err = e.WSRetrieveBookSummaryByInstrument(t.Context(), ps) - require.NoErrorf(t, err, "expected nil, got %v for pair %s", err, ps) - require.NotNilf(t, result, "expected result not to be nil for pair %s", ps) + require.NoErrorf(t, err, "request must not error for pair %s", ps) + require.NotNilf(t, result, "result must not be nil for pair %s", ps) } } @@ -534,8 +534,8 @@ func TestGetInstrumentData(t *testing.T) { var result *InstrumentData for assetType, cp := range assetTypeToPairsMap { result, err = e.GetInstrument(t.Context(), formatPairString(assetType, cp)) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) } } @@ -564,7 +564,7 @@ func TestGetInstruments(t *testing.T) { result, err = e.GetInstruments(t.Context(), currency.BTC, "", true) require.NoError(t, err) for a := range result { - require.Falsef(t, result[a].IsActive, "expected expired instrument, but got active instrument %s", result[a].InstrumentName) + require.Falsef(t, result[a].IsActive, "instrument must be expired but was active: %s", result[a].InstrumentName) } } @@ -671,8 +671,8 @@ func TestGetLastTradesByInstrument(t *testing.T) { for assetType, cp := range assetTypeToPairsMap { result, err := e.GetLastTradesByInstrument(t.Context(), formatPairString(assetType, cp), "30500", "31500", "desc", 0, true) - require.NoErrorf(t, err, "expected %v, got %v currency asset %v pair %v", nil, err, assetType, cp) - require.NotNilf(t, result, "expected value not to be nil for asset %v pair: %v", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %v pair %v", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %v pair %v", assetType, cp) } } @@ -683,8 +683,8 @@ func TestWSRetrieveLastTradesByInstrument(t *testing.T) { for assetType, cp := range assetTypeToPairsMap { result, err := e.WSRetrieveLastTradesByInstrument(t.Context(), formatPairString(assetType, cp), "30500", "31500", "desc", 0, true) - require.NoErrorf(t, err, "expected %v, got %v currency asset %v pair %v", nil, err, assetType, cp) - require.NotNilf(t, result, "expected value not to be nil for asset %v pair: %v", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %v pair %v", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %v pair %v", assetType, cp) } } @@ -695,8 +695,8 @@ func TestGetLastTradesByInstrumentAndTime(t *testing.T) { for assetType, cp := range assetTypeToPairsMap { result, err := e.GetLastTradesByInstrumentAndTime(t.Context(), formatPairString(assetType, cp), "", 0, time.Now().Add(-8*time.Hour), time.Now()) - require.NoErrorf(t, err, "expected %v, got %v currency pair %v", nil, err, cp) - require.NotNilf(t, result, "expected value not to be nil for pair: %v", cp) + require.NoErrorf(t, err, "request must not error for currency pair %v", cp) + require.NotNilf(t, result, "result must not be nil for pair %v", cp) } } @@ -707,8 +707,8 @@ func TestWSRetrieveLastTradesByInstrumentAndTime(t *testing.T) { for assetType, cp := range assetTypeToPairsMap { result, err := e.WSRetrieveLastTradesByInstrumentAndTime(t.Context(), formatPairString(assetType, cp), "", 0, true, time.Now().Add(-8*time.Hour), time.Now()) - require.NoErrorf(t, err, "expected %v, got %v currency pair %v", nil, err, cp) - require.NotNilf(t, result, "expected value not to be nil for pair: %v", cp) + require.NoErrorf(t, err, "request must not error for currency pair %v", cp) + require.NotNilf(t, result, "result must not be nil for pair %v", cp) } } @@ -717,7 +717,8 @@ func TestWSProcessTrades(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup instance must not error") - testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() a, p, err := getAssetPairByInstrument("BTC-PERPETUAL") @@ -767,8 +768,8 @@ func TestGetOrderbookData(t *testing.T) { var result *Orderbook for assetType, cp := range assetTypeToPairsMap { result, err = e.GetOrderbook(t.Context(), formatPairString(assetType, cp), 0) - require.NoErrorf(t, err, "expected %v, got %v currency pair %v", nil, err, cp) - require.NotNilf(t, result, "expected value not to be nil for pair: %v", cp) + require.NoErrorf(t, err, "request must not error for currency pair %v", cp) + require.NotNilf(t, result, "result must not be nil for pair %v", cp) } } @@ -783,8 +784,8 @@ func TestWSRetrieveOrderbookData(t *testing.T) { var result *Orderbook for assetType, cp := range assetTypeToPairsMap { result, err = e.WSRetrieveOrderbookData(t.Context(), formatPairString(assetType, cp), 0) - require.NoErrorf(t, err, "expected %v, got %v currency pair %v", nil, err, cp) - require.NotNilf(t, result, "expected value not to be nil for pair: %v", cp) + require.NoErrorf(t, err, "request must not error for currency pair %v", cp) + require.NotNilf(t, result, "result must not be nil for pair %v", cp) } } @@ -3283,7 +3284,7 @@ func setupWs() { if !sharedtestvalues.AreAPICredentialsSet(e) { e.Websocket.SetCanUseAuthenticatedEndpoints(false) } - err := e.WsConnect() + err := e.Websocket.Connect(context.TODO()) if err != nil { log.Fatal(err) } @@ -3418,8 +3419,8 @@ func TestCancelAllOrders(t *testing.T) { orderCancellation.AssetType = assetType orderCancellation.Pair = cp result, err = e.CancelAllOrders(t.Context(), orderCancellation) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) } } @@ -3428,8 +3429,8 @@ func TestGetOrderInfo(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, e) for assetType, cp := range assetTypeToPairsMap { result, err := e.GetOrderInfo(t.Context(), "1234", cp, assetType) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) } } @@ -3470,8 +3471,8 @@ func TestGetActiveOrders(t *testing.T) { getOrdersRequest.Pairs = []currency.Pair{cp} getOrdersRequest.AssetType = assetType result, err := e.GetActiveOrders(t.Context(), &getOrdersRequest) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) } } @@ -3483,8 +3484,8 @@ func TestGetOrderHistory(t *testing.T) { Type: order.AnyType, AssetType: assetType, Side: order.AnySide, Pairs: []currency.Pair{cp}, }) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s pair %s", err, assetType, cp) - require.NotNilf(t, result, "expected result not to be nil for asset type %s pair %s", assetType, cp) + require.NoErrorf(t, err, "request must not error for asset %s pair %s", assetType, cp) + require.NotNilf(t, result, "result must not be nil for asset %s pair %s", assetType, cp) } } @@ -3492,8 +3493,8 @@ func TestGetAssetPairByInstrument(t *testing.T) { t.Parallel() for _, assetType := range []asset.Item{asset.Spot, asset.Futures, asset.Options, asset.OptionCombo, asset.FutureCombo} { availablePairs, err := e.GetAvailablePairs(assetType) - require.NoErrorf(t, err, "expected nil, got %v for asset type %s", err, assetType) - require.NotNilf(t, availablePairs, "expected result not to be nil for asset type %s", assetType) + require.NoErrorf(t, err, "request must not error for asset type %s", assetType) + require.NotNilf(t, availablePairs, "available pairs must not be nil for asset type %s", assetType) for _, cp := range availablePairs { instrument := formatPairString(assetType, cp) t.Run(fmt.Sprintf("%s %s", assetType, instrument), func(t *testing.T) { @@ -3572,9 +3573,9 @@ func TestGetFeeByTypeOfflineTradeFee(t *testing.T) { require.NoError(t, err) require.NotNil(t, result) if !sharedtestvalues.AreAPICredentialsSet(e) { - assert.Equalf(t, exchange.OfflineTradeFee, feeBuilder.FeeType, "expected %v, received %v", exchange.OfflineTradeFee, feeBuilder.FeeType) + assert.Equalf(t, exchange.OfflineTradeFee, feeBuilder.FeeType, "fee type should match expected value; expected %v, got %v", exchange.OfflineTradeFee, feeBuilder.FeeType) } else { - assert.Equalf(t, exchange.CryptocurrencyTradeFee, feeBuilder.FeeType, "expected %v, received %v", exchange.CryptocurrencyTradeFee, feeBuilder.FeeType) + assert.Equalf(t, exchange.CryptocurrencyTradeFee, feeBuilder.FeeType, "fee type should match expected value; expected %v, got %v", exchange.CryptocurrencyTradeFee, feeBuilder.FeeType) } } @@ -3591,14 +3592,14 @@ func TestCalculateTradingFee(t *testing.T) { result, err := calculateTradingFee(feeBuilder) require.NoError(t, err) require.NotNil(t, result) - require.Equalf(t, 1e-1, result, "expected result %f, got %f", 1e-1, result) + require.Equalf(t, 1e-1, result, "result must equal %f; got %f", 1e-1, result) // futures feeBuilder.Pair, err = currency.NewPairFromString("BTC-21OCT22") require.NoError(t, err) result, err = calculateTradingFee(feeBuilder) require.NoError(t, err) require.NotNil(t, result) - require.Equalf(t, 0.1, result, "expected 0.1 but found %f", result) + require.Equalf(t, 0.1, result, "result must equal 0.1; got %f", result) // options feeBuilder.Pair, err = currency.NewPairFromString("SOL-21OCT22-20-C") require.NoError(t, err) @@ -3606,7 +3607,7 @@ func TestCalculateTradingFee(t *testing.T) { result, err = calculateTradingFee(feeBuilder) require.NoError(t, err) require.NotNil(t, result) - require.Equalf(t, 0.3, result, "expected 0.3 but found %f", result) + require.Equalf(t, 0.3, result, "result must equal 0.3; got %f", result) // options feeBuilder.Pair, err = currency.NewPairFromString("SOL-21OCT22-20-C,SOL-21OCT22-20-P") require.NoError(t, err) @@ -3614,7 +3615,7 @@ func TestCalculateTradingFee(t *testing.T) { _, err = calculateTradingFee(feeBuilder) require.NoError(t, err) require.NotNil(t, result) - require.Equalf(t, 0.3, result, "expected 0.3 but found %f", result) + require.Equalf(t, 0.3, result, "result must equal 0.3; got %f", result) // option_combo feeBuilder.Pair, err = currency.NewPairFromString("BTC-STRG-21OCT22-19000_21000") require.NoError(t, err) @@ -3723,7 +3724,7 @@ func TestProcessPushData(t *testing.T) { for k, v := range websocketPushData { t.Run(k, func(t *testing.T) { t.Parallel() - err := e.wsHandleData(t.Context(), []byte(v)) + err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), []byte(v)) require.NoError(t, err, "wsHandleData must not error") }) } @@ -4138,8 +4139,8 @@ func TestGetCurrencyTradeURL(t *testing.T) { for _, a := range e.GetAssetTypes(false) { var pairs currency.Pairs pairs, err = e.CurrencyPairs.GetPairs(a, false) - require.NoErrorf(t, err, "cannot get pairs for %s", a) - require.NotEmptyf(t, pairs, "no pairs for %s", a) + require.NoErrorf(t, err, "GetPairs must not error for asset %s", a) + require.NotEmptyf(t, pairs, "pairs must not be empty for asset %s", a) var resp string resp, err = e.GetCurrencyTradeURL(t.Context(), a, pairs[0]) require.NoError(t, err) @@ -4244,7 +4245,7 @@ func TestTimeInForceFromString(t *testing.T) { t.Parallel() for i := range timeInForceList { result, err := timeInForceFromString(timeInForceList[i].String, timeInForceList[i].PostOnly) - assert.Equalf(t, timeInForceList[i].TIF, result, "expected %s, got %s", timeInForceList[i].TIF.String(), result.String()) + assert.Equalf(t, timeInForceList[i].TIF, result, "value should match expected TIF; expected %s, got %s", timeInForceList[i].TIF.String(), result.String()) require.ErrorIs(t, err, timeInForceList[i].Error) } } diff --git a/exchanges/deribit/deribit_websocket.go b/exchanges/deribit/deribit_websocket.go index b3ef0f36bd7..2ad583709cc 100644 --- a/exchanges/deribit/deribit_websocket.go +++ b/exchanges/deribit/deribit_websocket.go @@ -103,29 +103,15 @@ var defaultSubscriptions = subscription.List{ {Enabled: true, Asset: asset.All, Channel: subscription.MyTradesChannel, Interval: kline.HundredMilliseconds, Authenticated: true}, } -// WsConnect starts a new connection with the websocket API -func (e *Exchange) WsConnect() error { - ctx := context.TODO() - if !e.Websocket.IsEnabled() || !e.IsEnabled() { - return websocket.ErrWebsocketNotEnabled - } - var dialer gws.Dialer - if err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}, nil); err != nil { +func (e *Exchange) wsConnect(ctx context.Context, conn websocket.Connection) error { + if err := conn.Dial(ctx, &gws.Dialer{}, http.Header{}, nil); err != nil { return err } - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx) - go e.wsStartHeartbeat(ctx) - if e.Websocket.CanUseAuthenticatedEndpoints() { - if err := e.wsLogin(ctx); err != nil { - log.Errorf(log.ExchangeSys, "%v - authentication failed: %v\n", e.Name, err) - e.Websocket.SetCanUseAuthenticatedEndpoints(false) - } - } + go e.wsStartHeartbeat(ctx, conn) return nil } -func (e *Exchange) wsStartHeartbeat(ctx context.Context) { +func (e *Exchange) wsStartHeartbeat(ctx context.Context, conn websocket.Connection) { msg := wsInput{ ID: e.MessageID(), JSONRPCVersion: rpcVersion, @@ -134,7 +120,7 @@ func (e *Exchange) wsStartHeartbeat(ctx context.Context) { "interval": 15, }, } - respRaw, err := e.Websocket.Conn.SendMessageReturnResponse(ctx, request.Unset, msg.ID, msg) + respRaw, err := conn.SendMessageReturnResponse(ctx, request.Unset, msg.ID, msg) if err != nil { log.Errorf(log.ExchangeSys, "%v %s: %s\n", e.Name, errStartingHeartbeat, err) return @@ -148,10 +134,7 @@ func (e *Exchange) wsStartHeartbeat(ctx context.Context) { } } -func (e *Exchange) wsLogin(ctx context.Context) error { - if !e.IsWebsocketAuthenticationSupported() { - return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", e.Name) - } +func (e *Exchange) wsAuthenticate(ctx context.Context, conn websocket.Connection) error { creds, err := e.GetCredentials(ctx) if err != nil { return err @@ -177,7 +160,7 @@ func (e *Exchange) wsLogin(ctx context.Context) error { "signature": hex.EncodeToString(hmac), }, } - resp, err := e.Websocket.Conn.SendMessageReturnResponse(ctx, request.Unset, req.ID, req) + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.ID, req) if err != nil { e.Websocket.SetCanUseAuthenticatedEndpoints(false) return err @@ -193,38 +176,21 @@ func (e *Exchange) wsLogin(ctx context.Context) error { return nil } -// wsReadData receives and passes on websocket messages for processing -func (e *Exchange) wsReadData(ctx context.Context) { - defer e.Websocket.Wg.Done() - - for { - resp := e.Websocket.Conn.ReadMessage() - if resp.Raw == nil { - return - } - if err := e.wsHandleData(ctx, resp.Raw); err != nil { - if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { - log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) - } - } - } -} - -func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, conn websocket.Connection, respRaw []byte) error { var response wsResponse err := json.Unmarshal(respRaw, &response) if err != nil { return fmt.Errorf("%s - err %s could not parse websocket data: %s", e.Name, err, respRaw) } if response.Method == "heartbeat" { - go e.wsSendHeartbeat(ctx) + go e.wsSendHeartbeat(ctx, conn) return nil } if response.ID != "" { if strings.HasPrefix(response.ID, "hb-") { return nil } - return e.Websocket.Match.RequireMatchWithData(response.ID, respRaw) + return conn.RequireMatchWithData(response.ID, respRaw) } channels := strings.Split(response.Params.Channel, ".") switch channels[0] { @@ -318,13 +284,13 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { return nil } -func (e *Exchange) wsSendHeartbeat(ctx context.Context) { +func (e *Exchange) wsSendHeartbeat(ctx context.Context, conn websocket.Connection) { msg := WsSubscriptionInput{ ID: "hb-" + e.MessageID(), JSONRPCVersion: rpcVersion, Method: "public/test", } - if err := e.Websocket.Conn.SendJSONMessage(ctx, request.Unset, msg); err != nil { + if err := conn.SendJSONMessage(ctx, request.Unset, msg); err != nil { log.Errorf(log.ExchangeSys, "%v %s: %s\n", e.Name, errSendingHeartbeat, err) } } @@ -841,21 +807,17 @@ func (e *Exchange) GetSubscriptionTemplate(_ *subscription.Subscription) (*templ Parse(subTplText) } -// Subscribe sends a websocket message to receive data from the channel -func (e *Exchange) Subscribe(subs subscription.List) error { - ctx := context.TODO() - errs := e.handleSubscription(ctx, "public/subscribe", subs.Public()) - return common.AppendError(errs, e.handleSubscription(ctx, "private/subscribe", subs.Private())) +func (e *Exchange) subscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + errs := e.handleSubscription(ctx, conn, "public/subscribe", subs.Public()) + return common.AppendError(errs, e.handleSubscription(ctx, conn, "private/subscribe", subs.Private())) } -// Unsubscribe sends a websocket message to stop receiving data from the channel -func (e *Exchange) Unsubscribe(subs subscription.List) error { - ctx := context.TODO() - errs := e.handleSubscription(ctx, "public/unsubscribe", subs.Public()) - return common.AppendError(errs, e.handleSubscription(ctx, "private/unsubscribe", subs.Private())) +func (e *Exchange) unsubscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + errs := e.handleSubscription(ctx, conn, "public/unsubscribe", subs.Public()) + return common.AppendError(errs, e.handleSubscription(ctx, conn, "private/unsubscribe", subs.Private())) } -func (e *Exchange) handleSubscription(ctx context.Context, method string, subs subscription.List) error { +func (e *Exchange) handleSubscription(ctx context.Context, conn websocket.Connection, method string, subs subscription.List) error { var err error subs, err = subs.ExpandTemplates(e) if err != nil || len(subs) == 0 { @@ -869,7 +831,7 @@ func (e *Exchange) handleSubscription(ctx context.Context, method string, subs s Params: map[string][]string{"channels": subs.QualifiedChannels()}, } - data, err := e.Websocket.Conn.SendMessageReturnResponse(ctx, request.Unset, r.ID, r) + data, err := conn.SendMessageReturnResponse(ctx, request.Unset, r.ID, r) if err != nil { return err } @@ -890,9 +852,9 @@ func (e *Exchange) handleSubscription(ctx context.Context, method string, subs s if _, ok := subAck[s.QualifiedChannel]; ok { delete(subAck, s.QualifiedChannel) if !strings.Contains(method, "unsubscribe") { - err = common.AppendError(err, e.Websocket.AddSuccessfulSubscriptions(e.Websocket.Conn, s)) + err = common.AppendError(err, e.Websocket.AddSuccessfulSubscriptions(conn, s)) } else { - err = common.AppendError(err, e.Websocket.RemoveSubscriptions(e.Websocket.Conn, s)) + err = common.AppendError(err, e.Websocket.RemoveSubscriptions(conn, s)) } } else { err = common.AppendError(err, errors.New(s.String()+" failed to "+method)) diff --git a/exchanges/deribit/deribit_wrapper.go b/exchanges/deribit/deribit_wrapper.go index 40e66896fac..8306f020767 100644 --- a/exchanges/deribit/deribit_wrapper.go +++ b/exchanges/deribit/deribit_wrapper.go @@ -147,6 +147,7 @@ func (e *Exchange) SetDefaults() { exchange.RestFutures: "https://www.deribit.com", exchange.RestSpot: "https://www.deribit.com", exchange.RestSpotSupplementary: "https://test.deribit.com", + exchange.WebsocketSpot: "wss://www.deribit.com/ws/api/v2", }) if err != nil { log.Errorln(log.ExchangeSys, err) @@ -172,14 +173,10 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } err = e.Websocket.Setup(&websocket.ManagerSetup{ - ExchangeConfig: exch, - DefaultURL: deribitWebsocketAddress, - RunningURL: deribitWebsocketAddress, - Connector: e.WsConnect, - Subscriber: e.Subscribe, - Unsubscriber: e.Unsubscribe, - GenerateSubscriptions: e.generateSubscriptions, - Features: &e.Features.Supports.WebsocketCapabilities, + ExchangeConfig: exch, + UseMultiConnectionManagement: true, + Features: &e.Features.Supports.WebsocketCapabilities, + MaxWebsocketSubscriptionsPerConnection: 500, // https://docs.deribit.com/ (max 500 channels per subscribe request) OrderbookBufferConfig: buffer.Config{ SortBuffer: true, SortBufferByUpdateIDs: true, @@ -189,10 +186,22 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } + wsRunningURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + if err != nil { + return err + } + return e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - URL: e.Websocket.GetWebsocketURL(), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + URL: wsRunningURL, + Connector: e.wsConnect, + Authenticate: e.wsAuthenticate, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generateSubscriptions, + Handler: e.wsHandleData, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + MessageFilter: wsRunningURL, }) } @@ -1123,11 +1132,6 @@ func (e *Exchange) GetServerTime(ctx context.Context, _ asset.Item) (time.Time, return e.GetTime(ctx) } -// AuthenticateWebsocket sends an authentication message to the websocket -func (e *Exchange) AuthenticateWebsocket(ctx context.Context) error { - return e.wsLogin(ctx) -} - // GetFuturesContractDetails returns all contracts from the exchange by asset type func (e *Exchange) GetFuturesContractDetails(ctx context.Context, item asset.Item) ([]futures.Contract, error) { if !item.IsFutures() { diff --git a/exchanges/deribit/deribit_ws_endpoints.go b/exchanges/deribit/deribit_ws_endpoints.go index fe1f1713fd0..960614ed14f 100644 --- a/exchanges/deribit/deribit_ws_endpoints.go +++ b/exchanges/deribit/deribit_ws_endpoints.go @@ -2293,8 +2293,12 @@ func (e *Exchange) sendWsPayload(ctx context.Context, ep request.EndpointLimit, if e.Verbose { log.Debugf(log.RequestSys, "%s attempt %d", e.Name, attempt) } + conn, err := e.Websocket.GetConnection(deribitWebsocketAddress) + if err != nil { + return err + } var payload []byte - payload, err = e.Websocket.Conn.SendMessageReturnResponse(ctx, request.Unset, input.ID, input) + payload, err = conn.SendMessageReturnResponse(ctx, request.Unset, input.ID, input) if err != nil { return err } diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 2e830cafd0f..b5da9e19d9f 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -23,6 +23,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchange/accounts" "github.com/thrasher-corp/gocryptotrader/exchange/websocket" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/fill" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" @@ -104,10 +105,6 @@ func (e *Exchange) WsConnectSpot(ctx context.Context, conn websocket.Connection) // websocketLogin authenticates the websocket connection func (e *Exchange) websocketLogin(ctx context.Context, conn websocket.Connection, channel string) error { - if conn == nil { - return fmt.Errorf("%w: %T", common.ErrNilPointer, conn) - } - if channel == "" { return errChannelEmpty } @@ -984,6 +981,11 @@ func (e *Exchange) SendWebsocketRequest(ctx context.Context, epl request.Endpoin if err != nil { return err } + if a, ok := connSignature.(asset.Item); ok { + if connSignature, err = e.websocketConnectionSignature(a); err != nil { + return err + } + } conn, err := e.Websocket.GetConnection(connSignature) if err != nil { @@ -1030,6 +1032,23 @@ func (e *Exchange) SendWebsocketRequest(ctx context.Context, epl request.Endpoin return json.Unmarshal(inbound.Data, &resultHolder{Result: result}) } +func (e *Exchange) websocketConnectionSignature(a asset.Item) (string, error) { + switch a { + case asset.Spot: + return e.API.Endpoints.GetURL(exchange.WebsocketSpot) + case asset.USDTMarginedFutures: + return e.API.Endpoints.GetURL(exchange.WebsocketUSDTMargined) + case asset.CoinMarginedFutures: + return e.API.Endpoints.GetURL(exchange.WebsocketCoinMargined) + case asset.DeliveryFutures: + return e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + case asset.Options: + return e.API.Endpoints.GetURL(exchange.WebsocketOptions) + default: + return "", fmt.Errorf("%w: websocket connection signature asset %q", asset.ErrNotSupported, a) + } +} + type wsRespAckInspector struct{} // IsFinal checks the payload for an ack, it returns true if the payload does not contain an ack. diff --git a/exchanges/gateio/gateio_websocket_futures.go b/exchanges/gateio/gateio_websocket_futures.go index da2a1e6303a..f417c159240 100644 --- a/exchanges/gateio/gateio_websocket_futures.go +++ b/exchanges/gateio/gateio_websocket_futures.go @@ -15,6 +15,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchange/accounts" "github.com/thrasher-corp/gocryptotrader/exchange/websocket" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/fill" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" @@ -72,7 +73,11 @@ var errNoChannelsSupplied = errors.New("no channels supplied") // WsFuturesConnect initiates a websocket connection for futures account func (e *Exchange) WsFuturesConnect(ctx context.Context, conn websocket.Connection) error { a := asset.USDTMarginedFutures - if conn.GetURL() == btcFuturesWebsocketURL { + wsCoinMarginedURL, err := e.API.Endpoints.GetURL(exchange.WebsocketCoinMargined) + if err != nil { + return err + } + if conn.GetURL() == wsCoinMarginedURL { a = asset.CoinMarginedFutures } if err := e.CurrencyPairs.IsAssetEnabled(a); err != nil { diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index ebcb6f7facc..8dfee901ce4 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" @@ -22,7 +23,9 @@ func TestWebsocketLogin(t *testing.T) { e := newExchangeWithWebsocket(t, asset.Spot) - c, err := e.Websocket.GetConnection(asset.Spot) + wsSpotURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + require.NoError(t, err) + c, err := e.Websocket.GetConnection(wsSpotURL) require.NoError(t, err) err = e.websocketLogin(t.Context(), c, "") diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index b902c826e53..5fab041b53a 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -168,10 +168,14 @@ func (e *Exchange) SetDefaults() { } e.API.Endpoints = e.NewEndpoints() err = e.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ - exchange.RestSpot: gateioTradeURL, - exchange.RestFutures: gateioFuturesLiveTradingAlternative, - exchange.RestSpotSupplementary: gateioFuturesTestnetTrading, - exchange.WebsocketSpot: gateioWebsocketEndpoint, + exchange.RestSpot: gateioTradeURL, + exchange.RestFutures: gateioFuturesLiveTradingAlternative, + exchange.RestSpotSupplementary: gateioFuturesTestnetTrading, + exchange.WebsocketSpot: gateioWebsocketEndpoint, + exchange.WebsocketUSDTMargined: usdtFuturesWebsocketURL, + exchange.WebsocketCoinMargined: btcFuturesWebsocketURL, + exchange.WebsocketSpotSupplementary: deliveryRealUSDTTradingURL, + exchange.WebsocketOptions: optionsWebsocketURL, }) if err != nil { log.Errorln(log.ExchangeSys, err) @@ -210,9 +214,30 @@ func (e *Exchange) Setup(exch *config.Exchange) error { if err != nil { return err } + wsSpotURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + if err != nil { + return err + } + wsUSDTFuturesURL, err := e.API.Endpoints.GetURL(exchange.WebsocketUSDTMargined) + if err != nil { + return err + } + wsCoinFuturesURL, err := e.API.Endpoints.GetURL(exchange.WebsocketCoinMargined) + if err != nil { + return err + } + wsDeliveryURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + if err != nil { + return err + } + wsOptionsURL, err := e.API.Endpoints.GetURL(exchange.WebsocketOptions) + if err != nil { + return err + } + // Spot connection err = e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - URL: gateioWebsocketEndpoint, + URL: wsSpotURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: e.WsHandleSpotData, @@ -221,14 +246,14 @@ func (e *Exchange) Setup(exch *config.Exchange) error { GenerateSubscriptions: e.generateSubscriptionsSpot, Connector: e.WsConnectSpot, Authenticate: e.authenticateSpot, - MessageFilter: asset.Spot, + MessageFilter: wsSpotURL, }) if err != nil { return err } // Futures connection - USDT margined err = e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - URL: usdtFuturesWebsocketURL, + URL: wsUSDTFuturesURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, conn websocket.Connection, incoming []byte) error { @@ -241,7 +266,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { }, Connector: e.WsFuturesConnect, Authenticate: e.authenticateFutures, - MessageFilter: asset.USDTMarginedFutures, + MessageFilter: wsUSDTFuturesURL, }) if err != nil { return err @@ -249,7 +274,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { // Futures connection - BTC margined err = e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - URL: btcFuturesWebsocketURL, + URL: wsCoinFuturesURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, conn websocket.Connection, incoming []byte) error { @@ -261,7 +286,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return e.GenerateFuturesDefaultSubscriptions(asset.CoinMarginedFutures) }, Connector: e.WsFuturesConnect, - MessageFilter: asset.CoinMarginedFutures, + MessageFilter: wsCoinFuturesURL, }) if err != nil { return err @@ -270,7 +295,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { // TODO: Add BTC margined delivery futures. // Futures connection - Delivery - USDT margined err = e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - URL: deliveryRealUSDTTradingURL, + URL: wsDeliveryURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: func(ctx context.Context, conn websocket.Connection, incoming []byte) error { @@ -280,7 +305,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Unsubscriber: e.DeliveryFuturesUnsubscribe, GenerateSubscriptions: e.GenerateDeliveryFuturesDefaultSubscriptions, Connector: e.WsDeliveryFuturesConnect, - MessageFilter: asset.DeliveryFutures, + MessageFilter: wsDeliveryURL, }) if err != nil { return err @@ -288,7 +313,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { // Futures connection - Options return e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - URL: optionsWebsocketURL, + URL: wsOptionsURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, Handler: e.WsHandleOptionsData, @@ -296,7 +321,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Unsubscriber: e.OptionsUnsubscribe, GenerateSubscriptions: e.GenerateOptionsDefaultSubscriptions, Connector: e.WsOptionsConnect, - MessageFilter: asset.Options, + MessageFilter: wsOptionsURL, }) } diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index baf6858b061..0259918c11e 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -37,80 +36,58 @@ var e *Exchange func TestGetSymbols(t *testing.T) { t.Parallel() _, err := e.GetSymbols(t.Context()) - if err != nil { - t.Error("GetSymbols() error", err) - } + assert.NoError(t, err, "GetSymbols should not error") } func TestFetchTradablePairs(t *testing.T) { t.Parallel() pairs, err := e.FetchTradablePairs(t.Context(), asset.Spot) - if err != nil { - t.Fatal(err) - } - if !pairs.Contains(currency.NewPair(currency.STORJ, currency.USD), false) { - t.Error("expected pair STORJ-USD") - } - if !pairs.Contains(currency.NewBTCUSD(), false) { - t.Error("expected pair BTC-USD") - } - if !pairs.Contains(currency.NewPair(currency.AAVE, currency.USD), false) { - t.Error("expected pair AAVE-BTC") - } + require.NoError(t, err) + assert.True(t, pairs.Contains(currency.NewPair(currency.STORJ, currency.USD), false), "tradable pairs should contain STORJ-USD") + assert.True(t, pairs.Contains(currency.NewBTCUSD(), false), "tradable pairs should contain BTC-USD") + assert.True(t, pairs.Contains(currency.NewPair(currency.AAVE, currency.USD), false), "tradable pairs should contain AAVE-USD") } func TestGetTicker(t *testing.T) { t.Parallel() _, err := e.GetTicker(t.Context(), "BTCUSD") - if err != nil { - t.Error("GetTicker() error", err) - } + assert.NoError(t, err, "GetTicker should not error") _, err = e.GetTicker(t.Context(), "bla") - if err == nil { - t.Error("GetTicker() Expected error") - } + assert.Error(t, err, "GetTicker should error for invalid symbol") } func TestGetOrderbook(t *testing.T) { t.Parallel() _, err := e.GetOrderbook(t.Context(), testCurrency, url.Values{}) - if err != nil { - t.Error("GetOrderbook() error", err) - } + assert.NoError(t, err, "GetOrderbook should not error") } func TestGetTrades(t *testing.T) { t.Parallel() _, err := e.GetTrades(t.Context(), testCurrency, 0, 0, false) - if err != nil { - t.Error("GetTrades() error", err) - } + assert.NoError(t, err, "GetTrades should not error") } func TestGetNotionalVolume(t *testing.T) { t.Parallel() _, err := e.GetNotionalVolume(t.Context()) if err != nil && mockTests { - t.Error("GetNotionalVolume() error", err) + assert.NoError(t, err, "GetNotionalVolume should not error in mock mode") } else if err == nil && !mockTests { - t.Error("GetNotionalVolume() error cannot be nil") + assert.Error(t, err, "GetNotionalVolume should error when credentials are unset") } } func TestGetAuction(t *testing.T) { t.Parallel() _, err := e.GetAuction(t.Context(), testCurrency) - if err != nil { - t.Error("GetAuction() error", err) - } + assert.NoError(t, err, "GetAuction should not error") } func TestGetAuctionHistory(t *testing.T) { t.Parallel() _, err := e.GetAuctionHistory(t.Context(), testCurrency, url.Values{}) - if err != nil { - t.Error("GetAuctionHistory() error", err) - } + assert.NoError(t, err, "GetAuctionHistory should not error") } func TestNewOrder(t *testing.T) { @@ -122,9 +99,9 @@ func TestNewOrder(t *testing.T) { order.Sell.Lower(), "exchange limit") if err != nil && mockTests { - t.Error("NewOrder() error", err) + assert.NoError(t, err, "NewOrder should not error in mock mode") } else if err == nil && !mockTests { - t.Error("NewOrder() error cannot be nil") + assert.Error(t, err, "NewOrder should error when credentials are unset") } } @@ -132,9 +109,9 @@ func TestCancelExistingOrder(t *testing.T) { t.Parallel() _, err := e.CancelExistingOrder(t.Context(), 265555413) if err != nil && mockTests { - t.Error("CancelExistingOrder() error", err) + assert.NoError(t, err, "CancelExistingOrder should not error in mock mode") } else if err == nil && !mockTests { - t.Error("CancelExistingOrder() error cannot be nil") + assert.Error(t, err, "CancelExistingOrder should error when credentials are unset") } } @@ -142,9 +119,9 @@ func TestCancelExistingOrders(t *testing.T) { t.Parallel() _, err := e.CancelExistingOrders(t.Context(), false) if err != nil && mockTests { - t.Error("CancelExistingOrders() error", err) + assert.NoError(t, err, "CancelExistingOrders should not error in mock mode") } else if err == nil && !mockTests { - t.Error("CancelExistingOrders() error cannot be nil") + assert.Error(t, err, "CancelExistingOrders should error when credentials are unset") } } @@ -152,9 +129,9 @@ func TestGetOrderStatus(t *testing.T) { t.Parallel() _, err := e.GetOrderStatus(t.Context(), 265563260) if err != nil && mockTests { - t.Error("GetOrderStatus() error", err) + assert.NoError(t, err, "GetOrderStatus should not error in mock mode") } else if err == nil && !mockTests { - t.Error("GetOrderStatus() error cannot be nil") + assert.Error(t, err, "GetOrderStatus should error when credentials are unset") } } @@ -162,9 +139,9 @@ func TestGetOrders(t *testing.T) { t.Parallel() _, err := e.GetOrders(t.Context()) if err != nil && mockTests { - t.Error("GetOrders() error", err) + assert.NoError(t, err, "GetOrders should not error in mock mode") } else if err == nil && !mockTests { - t.Error("GetOrders() error cannot be nil") + assert.Error(t, err, "GetOrders should error when credentials are unset") } } @@ -172,9 +149,9 @@ func TestGetTradeHistory(t *testing.T) { t.Parallel() _, err := e.GetTradeHistory(t.Context(), testCurrency, 0) if err != nil && mockTests { - t.Error("GetTradeHistory() error", err) + assert.NoError(t, err, "GetTradeHistory should not error in mock mode") } else if err == nil && !mockTests { - t.Error("GetTradeHistory() error cannot be nil") + assert.Error(t, err, "GetTradeHistory should error when credentials are unset") } } @@ -182,9 +159,9 @@ func TestGetTradeVolume(t *testing.T) { t.Parallel() _, err := e.GetTradeVolume(t.Context()) if err != nil && mockTests { - t.Error("GetTradeVolume() error", err) + assert.NoError(t, err, "GetTradeVolume should not error in mock mode") } else if err == nil && !mockTests { - t.Error("GetTradeVolume() error cannot be nil") + assert.Error(t, err, "GetTradeVolume should error when credentials are unset") } } @@ -192,35 +169,31 @@ func TestGetBalances(t *testing.T) { t.Parallel() _, err := e.GetBalances(t.Context()) if err != nil && mockTests { - t.Error("GetBalances() error", err) + assert.NoError(t, err, "GetBalances should not error in mock mode") } else if err == nil && !mockTests { - t.Error("GetBalances() error cannot be nil") + assert.Error(t, err, "GetBalances should error when credentials are unset") } } func TestGetCryptoDepositAddress(t *testing.T) { t.Parallel() _, err := e.GetCryptoDepositAddress(t.Context(), "LOL123", "btc") - if err == nil { - t.Error("GetCryptoDepositAddress() Expected error") - } + assert.Error(t, err, "GetCryptoDepositAddress should error for invalid account") } func TestWithdrawCrypto(t *testing.T) { t.Parallel() _, err := e.WithdrawCrypto(t.Context(), "LOL123", "btc", 1) - if err == nil { - t.Error("WithdrawCrypto() Expected error") - } + assert.Error(t, err, "WithdrawCrypto should error for invalid account") } func TestPostHeartbeat(t *testing.T) { t.Parallel() _, err := e.PostHeartbeat(t.Context()) if err != nil && mockTests { - t.Error("PostHeartbeat() error", err) + assert.NoError(t, err, "PostHeartbeat should not error in mock mode") } else if err == nil && !mockTests { - t.Error("PostHeartbeat() error cannot be nil") + assert.Error(t, err, "PostHeartbeat should error when credentials are unset") } } @@ -241,21 +214,11 @@ func TestGetFeeByTypeOfflineTradeFee(t *testing.T) { t.Parallel() feeBuilder := setFeeBuilder() _, err := e.GetFeeByType(t.Context(), feeBuilder) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if !sharedtestvalues.AreAPICredentialsSet(e) { - if feeBuilder.FeeType != exchange.OfflineTradeFee { - t.Errorf("Expected %v, received %v", - exchange.OfflineTradeFee, - feeBuilder.FeeType) - } + assert.Equal(t, exchange.OfflineTradeFee, feeBuilder.FeeType) } else { - if feeBuilder.FeeType != exchange.CryptocurrencyTradeFee { - t.Errorf("Expected %v, received %v", - exchange.CryptocurrencyTradeFee, - feeBuilder.FeeType) - } + assert.Equal(t, exchange.CryptocurrencyTradeFee, feeBuilder.FeeType) } } @@ -265,7 +228,7 @@ func TestGetFee(t *testing.T) { if sharedtestvalues.AreAPICredentialsSet(e) || mockTests { // CryptocurrencyTradeFee Basic if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } // CryptocurrencyTradeFee High quantity @@ -273,28 +236,28 @@ func TestGetFee(t *testing.T) { feeBuilder.Amount = 1000 feeBuilder.PurchasePrice = 1000 if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } // CryptocurrencyTradeFee IsMaker feeBuilder = setFeeBuilder() feeBuilder.IsMaker = true if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } // CryptocurrencyTradeFee Negative purchase price feeBuilder = setFeeBuilder() feeBuilder.PurchasePrice = -1000 if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } } // CryptocurrencyWithdrawalFee Basic feeBuilder = setFeeBuilder() feeBuilder.FeeType = exchange.CryptocurrencyWithdrawalFee if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } // CryptocurrencyWithdrawalFee Invalid currency @@ -302,21 +265,21 @@ func TestGetFee(t *testing.T) { feeBuilder.Pair.Base = currency.NewCode("hello") feeBuilder.FeeType = exchange.CryptocurrencyWithdrawalFee if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } // CryptocurrencyDepositFee Basic feeBuilder = setFeeBuilder() feeBuilder.FeeType = exchange.CryptocurrencyDepositFee if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } // InternationalBankDepositFee Basic feeBuilder = setFeeBuilder() feeBuilder.FeeType = exchange.InternationalBankDepositFee if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } // InternationalBankWithdrawalFee Basic @@ -324,7 +287,7 @@ func TestGetFee(t *testing.T) { feeBuilder.FeeType = exchange.InternationalBankWithdrawalFee feeBuilder.FiatCurrency = currency.USD if _, err := e.GetFee(t.Context(), feeBuilder); err != nil { - t.Error(err) + assert.NoError(t, err) } } @@ -336,11 +299,7 @@ func TestFormatWithdrawPermissions(t *testing.T) { " & " + exchange.WithdrawFiatViaWebsiteOnlyText withdrawPermissions := e.FormatWithdrawPermissions() - if withdrawPermissions != expectedResult { - t.Errorf("Expected: %s, Received: %s", - expectedResult, - withdrawPermissions) - } + assert.Equal(t, expectedResult, withdrawPermissions) } func TestGetActiveOrders(t *testing.T) { @@ -357,11 +316,11 @@ func TestGetActiveOrders(t *testing.T) { _, err := e.GetActiveOrders(t.Context(), &getOrdersRequest) switch { case sharedtestvalues.AreAPICredentialsSet(e) && err != nil && !mockTests: - t.Errorf("Could not get open orders: %s", err) + assert.NoError(t, err, "GetActiveOrders should not error") case !sharedtestvalues.AreAPICredentialsSet(e) && err == nil && !mockTests: - t.Error("Expecting an error when no keys are set") + assert.Error(t, err, "GetActiveOrders should error when no keys are set") case mockTests && err != nil: - t.Errorf("Could not get open orders: %s", err) + assert.NoError(t, err, "GetActiveOrders should not error") } } @@ -377,11 +336,11 @@ func TestGetOrderHistory(t *testing.T) { _, err := e.GetOrderHistory(t.Context(), &getOrdersRequest) switch { case sharedtestvalues.AreAPICredentialsSet(e) && err != nil: - t.Errorf("Could not get order history: %s", err) + assert.NoError(t, err, "GetOrderHistory should not error") case !sharedtestvalues.AreAPICredentialsSet(e) && err == nil && !mockTests: - t.Error("Expecting an error when no keys are set") + assert.Error(t, err, "GetOrderHistory should error when no keys are set") case err != nil && mockTests: - t.Errorf("Could not get order history: %s", err) + assert.NoError(t, err, "GetOrderHistory should not error") } } @@ -410,11 +369,12 @@ func TestSubmitOrder(t *testing.T) { response, err := e.SubmitOrder(t.Context(), orderSubmission) switch { case sharedtestvalues.AreAPICredentialsSet(e) && (err != nil || response.Status != order.New): - t.Errorf("Order failed to be placed: %v", err) + assert.NoError(t, err, "SubmitOrder should not error") + assert.Equal(t, order.New, response.Status, "SubmitOrder should return order.New status") case !sharedtestvalues.AreAPICredentialsSet(e) && err == nil && !mockTests: - t.Error("Expecting an error when no keys are set") + assert.Error(t, err, "SubmitOrder should error when no keys are set") case mockTests && err != nil: - t.Errorf("Order failed to be placed: %v", err) + assert.NoError(t, err, "SubmitOrder should not error") } } @@ -432,11 +392,11 @@ func TestCancelExchangeOrder(t *testing.T) { err := e.CancelOrder(t.Context(), orderCancellation) switch { case !sharedtestvalues.AreAPICredentialsSet(e) && err == nil && !mockTests: - t.Error("Expecting an error when no keys are set") + assert.Error(t, err, "CancelOrder should error when no keys are set") case sharedtestvalues.AreAPICredentialsSet(e) && err != nil: - t.Errorf("Could not cancel orders: %v", err) + assert.NoError(t, err, "CancelOrder should not error") case err != nil && mockTests: - t.Errorf("Could not cancel orders: %v", err) + assert.NoError(t, err, "CancelOrder should not error") } } @@ -457,16 +417,14 @@ func TestCancelAllExchangeOrders(t *testing.T) { resp, err := e.CancelAllOrders(t.Context(), orderCancellation) switch { case !sharedtestvalues.AreAPICredentialsSet(e) && err == nil && !mockTests: - t.Error("Expecting an error when no keys are set") + assert.Error(t, err, "CancelAllOrders should error when no keys are set") case sharedtestvalues.AreAPICredentialsSet(e) && err != nil: - t.Errorf("Could not cancel orders: %v", err) + assert.NoError(t, err, "CancelAllOrders should not error") case mockTests && err != nil: - t.Errorf("Could not cancel orders: %v", err) + assert.NoError(t, err, "CancelAllOrders should not error") } - if len(resp.Status) > 0 { - t.Errorf("%v orders failed to cancel", len(resp.Status)) - } + assert.Empty(t, resp.Status, "CancelAllOrders should return zero failed statuses") } func TestModifyOrder(t *testing.T) { @@ -474,9 +432,7 @@ func TestModifyOrder(t *testing.T) { sharedtestvalues.SkipTestIfCannotManipulateOrders(t, e, canManipulateRealOrders) _, err := e.ModifyOrder(t.Context(), &order.Modify{AssetType: asset.Spot}) - if err == nil { - t.Error("ModifyOrder() Expected error") - } + assert.Error(t, err, "ModifyOrder should error for incomplete request") } func TestWithdraw(t *testing.T) { @@ -497,13 +453,13 @@ func TestWithdraw(t *testing.T) { }, }) if !sharedtestvalues.AreAPICredentialsSet(e) && err == nil { - t.Error("Expecting an error when no keys are set") + assert.Error(t, err, "Withdraw should error when no keys are set") } if sharedtestvalues.AreAPICredentialsSet(e) && err != nil && !mockTests { - t.Errorf("Withdraw failed to be placed: %v", err) + assert.NoError(t, err, "Withdraw should not error") } if sharedtestvalues.AreAPICredentialsSet(e) && err == nil && mockTests { - t.Errorf("Withdraw failed to be placed: %v", err) + assert.Error(t, err, "Withdraw should error in mock mode with credentials") } } @@ -515,11 +471,7 @@ func TestWithdrawFiat(t *testing.T) { withdrawFiatRequest := withdraw.Request{} _, err := e.WithdrawFiatFunds(t.Context(), &withdrawFiatRequest) - if err != common.ErrFunctionNotSupported { - t.Errorf("Expected '%v', received: '%v'", - common.ErrFunctionNotSupported, - err) - } + assert.ErrorIs(t, err, common.ErrFunctionNotSupported) } func TestWithdrawInternationalBank(t *testing.T) { @@ -531,47 +483,26 @@ func TestWithdrawInternationalBank(t *testing.T) { withdrawFiatRequest := withdraw.Request{} _, err := e.WithdrawFiatFundsToInternationalBank(t.Context(), &withdrawFiatRequest) - if err != common.ErrFunctionNotSupported { - t.Errorf("Expected '%v', received: '%v'", - common.ErrFunctionNotSupported, - err) - } + assert.ErrorIs(t, err, common.ErrFunctionNotSupported) } func TestGetDepositAddress(t *testing.T) { t.Parallel() _, err := e.GetDepositAddress(t.Context(), currency.BTC, "", "") - if err == nil { - t.Error("GetDepositAddress error cannot be nil") - } + assert.Error(t, err, "GetDepositAddress should error when account details are missing") } func TestWsAuth(t *testing.T) { t.Parallel() - err := e.API.Endpoints.SetRunningURL(exchange.WebsocketSpot.String(), geminiWebsocketSandboxEndpoint) - if err != nil { - t.Error(err) + if !e.Websocket.IsEnabled() || + !e.API.AuthenticatedWebsocketSupport || + !sharedtestvalues.AreAPICredentialsSet(e) { + t.Skip("authenticated websocket is not available for this test") } testexch.SkipTestIfCannotUseAuthenticatedWebsocket(t, e) - var dialer gws.Dialer - err = e.WsAuth(t.Context(), &dialer) - if err != nil { - t.Error(err) - } - timer := time.NewTimer(sharedtestvalues.WebsocketResponseDefaultTimeout) - select { - case resp := <-e.Websocket.DataHandler.C: - subAck, ok := resp.Data.(WsSubscriptionAcknowledgementResponse) - if !ok { - t.Error("unable to type assert WsSubscriptionAcknowledgementResponse") - } - if subAck.Type != "subscription_ack" { - t.Error("Login failed") - } - case <-timer.C: - t.Error("Expected response") - } - timer.Stop() + require.NoError(t, e.API.Endpoints.SetRunningURL(exchange.WebsocketSpotSupplementary.String(), geminiWebsocketSandboxEndpoint+geminiWsOrderEvents)) + conn := testexch.GetMockConn(t, e, geminiWebsocketSandboxEndpoint+geminiWsOrderEvents) + require.NoError(t, e.wsAuthConnect(t.Context(), conn)) } func TestWsMissingRole(t *testing.T) { @@ -580,9 +511,7 @@ func TestWsMissingRole(t *testing.T) { "reason":"MissingRole", "message":"To access this endpoint, you need to log in to the website and go to the settings page to assign one of these roles [FundManager] to API key wujB3szN54gtJ4QDhqRJ which currently has roles [Trader]" }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err == nil { - t.Error("Expected error") - } + assert.Error(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should return an error") } func TestWsOrderEventSubscriptionResponse(t *testing.T) { @@ -604,10 +533,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "original_amount" : "14.0296", "price" : "1059.54" } ]`) - err := e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") pressXToJSON = []byte(`[{ "type": "accepted", @@ -626,10 +552,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "price": "3592.00", "socket_sequence": 13 }]`) - err = e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") pressXToJSON = []byte(`[{ "type": "accepted", @@ -647,10 +570,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "total_spend": "200.00", "socket_sequence": 29 }]`) - err = e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") pressXToJSON = []byte(`[{ "type": "accepted", @@ -668,10 +588,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "original_amount": "25", "socket_sequence": 26 }]`) - err = e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") pressXToJSON = []byte(`[ { "type" : "accepted", @@ -690,10 +607,7 @@ func TestWsOrderEventSubscriptionResponse(t *testing.T) { "original_amount" : "500", "socket_sequence" : 32307 } ]`) - err = e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") } func TestWsSubAck(t *testing.T) { @@ -712,8 +626,8 @@ func TestWsSubAck(t *testing.T) { "closed" ] }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } } @@ -725,8 +639,8 @@ func TestWsHeartbeat(t *testing.T) { "trace_id": "b8biknoqppr32kc7gfgg", "socket_sequence": 37 }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } } @@ -746,10 +660,7 @@ func TestWsUnsubscribe(t *testing.T) { ]} ] }`) - err := e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") } func TestWsTradeData(t *testing.T) { @@ -769,8 +680,8 @@ func TestWsTradeData(t *testing.T) { } ] }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } } @@ -801,8 +712,8 @@ func TestWsAuctionData(t *testing.T) { ], "type": "update" }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } } @@ -822,8 +733,8 @@ func TestWsBlockTrade(t *testing.T) { } ] }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } } @@ -837,8 +748,8 @@ func TestWSTrade(t *testing.T) { "quantity": "0.09110000", "side": "buy" }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } } @@ -870,7 +781,7 @@ func TestWsCandles(t *testing.T) { ] ] }`) - require.NoError(t, g.wsHandleData(t.Context(), pressXToJSON)) + require.NoError(t, g.wsHandleData(t.Context(), testexch.GetMockConn(t, g, ""), pressXToJSON)) for _, exp := range []kline.Candle{ { @@ -943,8 +854,8 @@ func TestWsAuctions(t *testing.T) { ], "type": "update" }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } pressXToJSON = []byte(`{ @@ -967,8 +878,8 @@ func TestWsAuctions(t *testing.T) { } ] }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } pressXToJSON = []byte(`{ @@ -998,8 +909,8 @@ func TestWsAuctions(t *testing.T) { } ] }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } } @@ -1027,10 +938,7 @@ func TestWsMarketData(t *testing.T) { } ] } `) - err := e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") pressXToJSON = []byte(`{ "type": "update", @@ -1055,10 +963,7 @@ func TestWsMarketData(t *testing.T) { } ] } `) - err = e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") pressXToJSON = []byte(`{ "type": "update", @@ -1077,52 +982,60 @@ func TestWsMarketData(t *testing.T) { } ] } `) - err = e.wsHandleData(t.Context(), pressXToJSON) - if err != nil { - t.Error(err) - } + assert.NoError(t, e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON), "wsHandleData should not error") } func TestWsError(t *testing.T) { + t.Parallel() tt := []struct { + Case string Data []byte ErrorExpected bool ErrorShouldContain string }{ { + Case: "no error payload type", Data: []byte(`{"type": "test"}`), ErrorExpected: false, }, { + Case: "no error payload result", Data: []byte(`{"result": "bla"}`), ErrorExpected: false, }, { + Case: "generic websocket error", Data: []byte(`{"result": "error"}`), ErrorExpected: true, ErrorShouldContain: "Unhandled websocket error", }, { + Case: "invalid json reason", Data: []byte(`{"result": "error","reason": "InvalidJson"}`), ErrorExpected: true, ErrorShouldContain: "InvalidJson", }, { + Case: "invalid json reason with message", Data: []byte(`{"result": "error","reason": "InvalidJson", "message": "WeAreGoingToTheMoonKirby"}`), ErrorExpected: true, ErrorShouldContain: "InvalidJson - WeAreGoingToTheMoonKirby", }, } - for x := range tt { - err := e.wsHandleData(t.Context(), tt[x].Data) - if tt[x].ErrorExpected && err != nil && !strings.Contains(err.Error(), tt[x].ErrorShouldContain) { - t.Errorf("expected error to contain: %s, got: %s", - tt[x].ErrorShouldContain, err.Error(), - ) - } else if !tt[x].ErrorExpected && err != nil { - t.Errorf("unexpected error: %s", err) - } + for _, tc := range tt { + t.Run(tc.Case, func(t *testing.T) { + t.Parallel() + ex := new(Exchange) + require.NoError(t, testexch.Setup(ex), "Test instance Setup must not error") + + err := ex.wsHandleData(t.Context(), testexch.GetMockConn(t, ex, ""), tc.Data) + if tc.ErrorExpected && err != nil && !strings.Contains(err.Error(), tc.ErrorShouldContain) { + assert.Contains(t, err.Error(), tc.ErrorShouldContain, "error should contain expected substring") + } else if !tc.ErrorExpected && err != nil { + assert.NoError(t, err, "wsHandleData should not error for this payload") + } + }) } } @@ -1175,12 +1088,13 @@ func TestWsLevel2Update(t *testing.T) { } ] }`) - if err := e.wsHandleData(t.Context(), pressXToJSON); err != nil { - t.Error(err) + if err := e.wsHandleData(t.Context(), testexch.GetMockConn(t, e, ""), pressXToJSON); err != nil { + assert.NoError(t, err) } } func TestResponseToStatus(t *testing.T) { + t.Parallel() type TestCases struct { Case string Result order.Status @@ -1195,14 +1109,16 @@ func TestResponseToStatus(t *testing.T) { {Case: "LOL", Result: order.UnknownStatus}, } for i := range testCases { - result, _ := stringToOrderStatus(testCases[i].Case) - if result != testCases[i].Result { - t.Errorf("Expected: %v, received: %v", testCases[i].Result, result) - } + t.Run(testCases[i].Case, func(t *testing.T) { + t.Parallel() + result, _ := stringToOrderStatus(testCases[i].Case) + assert.Equal(t, testCases[i].Result, result, "order status should match expected conversion") + }) } } func TestResponseToOrderType(t *testing.T) { + t.Parallel() type TestCases struct { Case string Result order.Type @@ -1217,10 +1133,11 @@ func TestResponseToOrderType(t *testing.T) { {Case: "LOL", Result: order.UnknownType}, } for i := range testCases { - result, _ := stringToOrderType(testCases[i].Case) - if result != testCases[i].Result { - t.Errorf("Expected: %v, received: %v", testCases[i].Result, result) - } + t.Run(testCases[i].Case, func(t *testing.T) { + t.Parallel() + result, _ := stringToOrderType(testCases[i].Case) + assert.Equal(t, testCases[i].Result, result, "order type should match expected conversion") + }) } } @@ -1228,11 +1145,11 @@ func TestGetRecentTrades(t *testing.T) { t.Parallel() currencyPair, err := currency.NewPairFromString(testCurrency) if err != nil { - t.Fatal(err) + require.NoError(t, err) } _, err = e.GetRecentTrades(t.Context(), currencyPair, asset.Spot) if err != nil { - t.Error(err) + assert.NoError(t, err) } } @@ -1240,7 +1157,7 @@ func TestGetHistoricTrades(t *testing.T) { t.Parallel() currencyPair, err := currency.NewPairFromString(testCurrency) if err != nil { - t.Fatal(err) + require.NoError(t, err) } tStart := time.Date(2020, 6, 6, 0, 0, 0, 0, time.UTC) tEnd := time.Date(2020, 6, 7, 0, 0, 0, 0, time.UTC) @@ -1251,7 +1168,7 @@ func TestGetHistoricTrades(t *testing.T) { _, err = e.GetHistoricTrades(t.Context(), currencyPair, asset.Spot, tStart, tEnd) if err != nil { - t.Error(err) + assert.NoError(t, err) } } @@ -1261,7 +1178,7 @@ func TestTransfers(t *testing.T) { _, err := e.Transfers(t.Context(), currency.BTC, time.Time{}, 100, "", true) if err != nil { - t.Error(err) + assert.NoError(t, err) } } @@ -1271,7 +1188,7 @@ func TestGetAccountFundingHistory(t *testing.T) { _, err := e.GetAccountFundingHistory(t.Context()) if err != nil { - t.Error(err) + assert.NoError(t, err) } } @@ -1281,7 +1198,7 @@ func TestGetWithdrawalsHistory(t *testing.T) { _, err := e.GetWithdrawalsHistory(t.Context(), currency.BTC, asset.Spot) if err != nil { - t.Error(err) + assert.NoError(t, err) } } @@ -1291,7 +1208,7 @@ func TestGetOrderInfo(t *testing.T) { _, err := e.GetOrderInfo(t.Context(), "1234", currency.EMPTYPAIR, asset.Empty) if err != nil { - t.Error(err) + assert.NoError(t, err) } } @@ -1299,11 +1216,11 @@ func TestGetSymbolDetails(t *testing.T) { t.Parallel() _, err := e.GetSymbolDetails(t.Context(), "all") if err != nil { - t.Error(err) + assert.NoError(t, err) } _, err = e.GetSymbolDetails(t.Context(), "btcusd") if err != nil { - t.Error(err) + assert.NoError(t, err) } } @@ -1333,12 +1250,15 @@ func TestGetCurrencyTradeURL(t *testing.T) { t.Parallel() testexch.UpdatePairsOnce(t, e) for _, a := range e.GetAssetTypes(false) { - pairs, err := e.CurrencyPairs.GetPairs(a, false) - require.NoErrorf(t, err, "cannot get pairs for %s", a) - require.NotEmptyf(t, pairs, "no pairs for %s", a) - resp, err := e.GetCurrencyTradeURL(t.Context(), a, pairs[0]) - require.NoError(t, err) - assert.NotEmpty(t, resp) + t.Run(a.String(), func(t *testing.T) { + t.Parallel() + pairs, err := e.CurrencyPairs.GetPairs(a, false) + require.NoErrorf(t, err, "GetPairs must not error for asset %s", a) + require.NotEmptyf(t, pairs, "pairs must not be empty for asset %s", a) + resp, err := e.GetCurrencyTradeURL(t.Context(), a, pairs[0]) + require.NoError(t, err, "GetCurrencyTradeURL must not error") + assert.NotEmpty(t, resp, "GetCurrencyTradeURL should return a URL") + }) } } @@ -1358,10 +1278,13 @@ func TestGenerateSubscriptions(t *testing.T) { testsubs.EqualLists(t, exp, subs) for _, i := range []kline.Interval{kline.OneMin, kline.FiveMin, kline.FifteenMin, kline.ThirtyMin, kline.OneHour, kline.SixHour} { - subs, err = subscription.List{{Asset: asset.Spot, Channel: subscription.CandlesChannel, Pairs: p, Interval: i}}.ExpandTemplates(e) - assert.NoErrorf(t, err, "ExpandTemplates should not error on interval %s", i) - require.NotEmpty(t, subs) - assert.Equal(t, "candles_"+i.Short(), subs[0].QualifiedChannel) + t.Run(i.String(), func(t *testing.T) { + t.Parallel() + subRes, err := subscription.List{{Asset: asset.Spot, Channel: subscription.CandlesChannel, Pairs: p, Interval: i}}.ExpandTemplates(e) + assert.NoErrorf(t, err, "ExpandTemplates should not error on interval %s", i) + require.NotEmpty(t, subRes, "ExpandTemplates result must not be empty") + assert.Equal(t, "candles_"+i.Short(), subRes[0].QualifiedChannel, "qualified channel should match expected interval") + }) } _, err = subscription.List{{Asset: asset.Spot, Channel: subscription.CandlesChannel, Pairs: p, Interval: kline.FourHour}}.ExpandTemplates(e) assert.ErrorIs(t, err, kline.ErrUnsupportedInterval, "ExpandTemplates should error on invalid interval") diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index d68b39f2c08..367bfe81f7d 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -27,7 +27,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" - "github.com/thrasher-corp/gocryptotrader/log" ) const ( @@ -54,30 +53,8 @@ var subscriptionNames = map[string]string{ subscription.OrderbookChannel: marketDataLevel2, } -// WsConnect initiates a websocket connection -func (e *Exchange) WsConnect() error { - ctx := context.TODO() - if !e.Websocket.IsEnabled() || !e.IsEnabled() { - return websocket.ErrWebsocketNotEnabled - } - - var dialer gws.Dialer - err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}, nil) - if err != nil { - return err - } - - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.Conn) - - if e.Websocket.CanUseAuthenticatedEndpoints() { - err := e.WsAuth(ctx, &dialer) - if err != nil { - log.Errorf(log.ExchangeSys, "%v - websocket authentication failed: %v\n", e.Name, err) - e.Websocket.SetCanUseAuthenticatedEndpoints(false) - } - } - return nil +func (e *Exchange) wsConnect(ctx context.Context, conn websocket.Connection) error { + return conn.Dial(ctx, &gws.Dialer{}, http.Header{}, nil) } // generateSubscriptions returns a list of subscriptions from the configured subscriptions feature @@ -85,6 +62,14 @@ func (e *Exchange) generateSubscriptions() (subscription.List, error) { return e.Features.Subscriptions.ExpandTemplates(e) } +func (e *Exchange) generatePublicSubscriptions() (subscription.List, error) { + subs, err := e.generateSubscriptions() + if err != nil { + return nil, err + } + return subs.Public(), nil +} + // GetSubscriptionTemplate returns a subscription channel template func (e *Exchange) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) { return template.New("master.tmpl").Funcs(template.FuncMap{ @@ -93,19 +78,15 @@ func (e *Exchange) GetSubscriptionTemplate(_ *subscription.Subscription) (*templ }).Parse(subTplText) } -// Subscribe sends a websocket message to receive data from the channel -func (e *Exchange) Subscribe(subs subscription.List) error { - ctx := context.TODO() - return e.manageSubs(ctx, subs, wsSubscribeOp) +func (e *Exchange) subscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return e.manageSubs(ctx, conn, subs, wsSubscribeOp) } -// Unsubscribe sends a websocket message to stop receiving data from the channel -func (e *Exchange) Unsubscribe(subs subscription.List) error { - ctx := context.TODO() - return e.manageSubs(ctx, subs, wsUnsubscribeOp) +func (e *Exchange) unsubscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return e.manageSubs(ctx, conn, subs, wsUnsubscribeOp) } -func (e *Exchange) manageSubs(ctx context.Context, subs subscription.List, op wsSubOp) error { +func (e *Exchange) manageSubs(ctx context.Context, conn websocket.Connection, subs subscription.List, op wsSubOp) error { req := wsSubscribeRequest{ Type: op, Subscriptions: make([]wsSubscriptions, 0, len(subs)), @@ -117,22 +98,18 @@ func (e *Exchange) manageSubs(ctx context.Context, subs subscription.List, op ws }) } - if err := e.Websocket.Conn.SendJSONMessage(ctx, request.Unset, req); err != nil { + if err := conn.SendJSONMessage(ctx, request.Unset, req); err != nil { return err } if op == wsUnsubscribeOp { - return e.Websocket.RemoveSubscriptions(e.Websocket.Conn, subs...) + return e.Websocket.RemoveSubscriptions(conn, subs...) } - return e.Websocket.AddSuccessfulSubscriptions(e.Websocket.Conn, subs...) + return e.Websocket.AddSuccessfulSubscriptions(conn, subs...) } -// WsAuth will connect to Gemini's secure endpoint -func (e *Exchange) WsAuth(ctx context.Context, dialer *gws.Dialer) error { - if !e.IsWebsocketAuthenticationSupported() { - return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", e.Name) - } +func (e *Exchange) wsAuthConnect(ctx context.Context, conn websocket.Connection) error { creds, err := e.GetCredentials(ctx) if err != nil { return err @@ -145,11 +122,10 @@ func (e *Exchange) WsAuth(ctx context.Context, dialer *gws.Dialer) error { if err != nil { return fmt.Errorf("%v sendAuthenticatedHTTPRequest: Unable to JSON request", e.Name) } - wsEndpoint, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + endpoint, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) if err != nil { return err } - endpoint := wsEndpoint + geminiWsOrderEvents payloadB64 := base64.StdEncoding.EncodeToString(payloadJSON) hmac, err := crypto.GetHMAC(crypto.HashSHA512_384, []byte(payloadB64), []byte(creds.Secret)) if err != nil { @@ -164,31 +140,14 @@ func (e *Exchange) WsAuth(ctx context.Context, dialer *gws.Dialer) error { headers.Add("X-GEMINI-SIGNATURE", hex.EncodeToString(hmac)) headers.Add("Cache-Control", "no-cache") - err = e.Websocket.AuthConn.Dial(ctx, dialer, headers, nil) + err = conn.Dial(ctx, &gws.Dialer{}, headers, nil) if err != nil { return fmt.Errorf("%v Websocket connection %v error. Error %v", e.Name, endpoint, err) } - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.AuthConn) return nil } -func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) { - defer e.Websocket.Wg.Done() - for { - resp := ws.ReadMessage() - if resp.Raw == nil { - return - } - if err := e.wsHandleData(ctx, resp.Raw); err != nil { - if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { - log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, ws.GetURL(), errSend, err) - } - } - } -} - -func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, _ websocket.Connection, respRaw []byte) error { // only order details are sent in arrays if strings.HasPrefix(string(respRaw), "[") { var result []WsOrderResponse diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index 7e678865dab..198f07c2009 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -102,8 +102,9 @@ func (e *Exchange) SetDefaults() { } e.API.Endpoints = e.NewEndpoints() err = e.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ - exchange.RestSpot: geminiAPIURL, - exchange.WebsocketSpot: geminiWebsocketEndpoint, + exchange.RestSpot: geminiAPIURL, + exchange.WebsocketSpot: geminiWebsocketEndpoint + "/v2/" + geminiWsMarketData, + exchange.WebsocketSpotSupplementary: geminiWebsocketEndpoint + "/v1/" + geminiWsOrderEvents, }) if err != nil { log.Errorln(log.ExchangeSys, err) @@ -136,39 +137,51 @@ func (e *Exchange) Setup(exch *config.Exchange) error { } } - wsRunningURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) + wsPublicURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpot) if err != nil { return err } err = e.Websocket.Setup(&websocket.ManagerSetup{ - ExchangeConfig: exch, - DefaultURL: geminiWebsocketEndpoint, - RunningURL: wsRunningURL, - Connector: e.WsConnect, - Subscriber: e.Subscribe, - Unsubscriber: e.Unsubscribe, - GenerateSubscriptions: e.generateSubscriptions, - Features: &e.Features.Supports.WebsocketCapabilities, + ExchangeConfig: exch, + UseMultiConnectionManagement: true, + Features: &e.Features.Supports.WebsocketCapabilities, }) if err != nil { return err } err = e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - URL: geminiWebsocketEndpoint + "/v2/" + geminiWsMarketData, + URL: wsPublicURL, + Connector: e.wsConnect, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generatePublicSubscriptions, + Handler: e.wsHandleData, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + MessageFilter: wsPublicURL, }) if err != nil { return err } + authWSURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + if err != nil { + return err + } + if !exch.API.AuthenticatedWebsocketSupport { + return nil + } + return e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - URL: geminiWebsocketEndpoint + "/v1/" + geminiWsOrderEvents, - Authenticated: true, + URL: authWSURL, + Connector: e.wsAuthConnect, + SubscriptionsNotRequired: true, + Handler: e.wsHandleData, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + MessageFilter: authWSURL, }) } diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index 6b64d961748..7bdd144646d 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -718,11 +718,6 @@ func (e *Exchange) GetOrderHistory(ctx context.Context, req *order.MultiOrderReq return req.Filter(e.Name, orders), nil } -// AuthenticateWebsocket sends an authentication message to the websocket -func (e *Exchange) AuthenticateWebsocket(ctx context.Context) error { - return e.wsLogin(ctx) -} - // ValidateAPICredentials validates current credentials used for wrapper functionality func (e *Exchange) ValidateAPICredentials(ctx context.Context, assetType asset.Item) error { _, err := e.UpdateAccountBalances(ctx, assetType) diff --git a/exchanges/huobi/huobi_test.go b/exchanges/huobi/huobi_test.go index afd3a1ecac5..19e4a0ee910 100644 --- a/exchanges/huobi/huobi_test.go +++ b/exchanges/huobi/huobi_test.go @@ -1,8 +1,7 @@ package huobi import ( - "errors" - "fmt" + "context" "log" "os" "strconv" @@ -10,8 +9,6 @@ import ( "testing" "time" - "github.com/buger/jsonparser" - gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -31,7 +28,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/trade" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" - mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" "github.com/thrasher-corp/gocryptotrader/types" ) @@ -1286,9 +1282,10 @@ func TestWSCandles(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "market.btcusdt.kline.1min", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.CandlesChannel}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Key: "market.btcusdt.kline.1min", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.CandlesChannel}) require.NoError(t, err, "AddSubscriptions must not error") - testexch.FixtureToDataHandler(t, "testdata/wsCandles.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsCandles.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() require.Len(t, e.Websocket.DataHandler.C, 1, "Must see correct number of records") cAny := <-e.Websocket.DataHandler.C @@ -1316,9 +1313,10 @@ func TestWSOrderbook(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "market.btcusdt.depth.step0", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.OrderbookChannel}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Key: "market.btcusdt.depth.step0", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.OrderbookChannel}) require.NoError(t, err, "AddSubscriptions must not error") - testexch.FixtureToDataHandler(t, "testdata/wsOrderbook.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsOrderbook.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() require.Len(t, e.Websocket.DataHandler.C, 1, "Must see correct number of records") dAny := <-e.Websocket.DataHandler.C @@ -1344,10 +1342,11 @@ func TestWSHandleAllTradesMsg(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "market.btcusdt.trade.detail", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.AllTradesChannel}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Key: "market.btcusdt.trade.detail", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.AllTradesChannel}) require.NoError(t, err, "AddSubscriptions must not error") e.SetSaveTradeDataStatus(true) - testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() exp := []trade.Data{ { @@ -1390,9 +1389,10 @@ func TestWSTicker(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "market.btcusdt.detail", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.TickerChannel}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Key: "market.btcusdt.detail", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.TickerChannel}) require.NoError(t, err, "AddSubscriptions must not error") - testexch.FixtureToDataHandler(t, "testdata/wsTicker.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsTicker.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() require.Len(t, e.Websocket.DataHandler.C, 1, "Must see correct number of records") tickAny := <-e.Websocket.DataHandler.C @@ -1419,10 +1419,11 @@ func TestWSAccountUpdate(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "accounts.update#2", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.MyAccountChannel}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Key: "accounts.update#2", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.MyAccountChannel}) require.NoError(t, err, "AddSubscriptions must not error") e.SetSaveTradeDataStatus(true) - testexch.FixtureToDataHandler(t, "testdata/wsMyAccount.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsMyAccount.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() require.Len(t, e.Websocket.DataHandler.C, 3, "Must see correct number of records") exp := []WsAccountUpdate{ @@ -1443,10 +1444,11 @@ func TestWSOrderUpdate(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "orders#*", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.MyOrdersChannel}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Key: "orders#*", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.MyOrdersChannel}) require.NoError(t, err, "AddSubscriptions must not error") e.SetSaveTradeDataStatus(true) - errs := testexch.FixtureToDataHandlerWithErrors(t, "testdata/wsMyOrders.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + errs := testexch.FixtureToDataHandlerWithErrors(t, "testdata/wsMyOrders.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() require.Equal(t, 1, len(errs), "Must receive the correct number of errors back") require.ErrorContains(t, errs[0].Err, "error with order \"test1\": invalid.client.order.id (NT) (2002)") @@ -1509,10 +1511,11 @@ func TestWSMyTrades(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Key: "trade.clearing#btcusdt#1", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.MyTradesChannel}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Key: "trade.clearing#btcusdt#1", Asset: asset.Spot, Pairs: currency.Pairs{btcusdtPair}, Channel: subscription.MyTradesChannel}) require.NoError(t, err, "AddSubscriptions must not error") e.SetSaveTradeDataStatus(true) - testexch.FixtureToDataHandler(t, "testdata/wsMyTrades.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsMyTrades.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() require.Len(t, e.Websocket.DataHandler.C, 1, "Must see correct number of records") m := <-e.Websocket.DataHandler.C @@ -1791,12 +1794,12 @@ func TestUpdateTickers(t *testing.T) { updatePairsOnce(t, e) for _, a := range e.GetAssetTypes(false) { err := e.UpdateTickers(t.Context(), a) - require.NoErrorf(t, err, "asset %s", a) + require.NoErrorf(t, err, "UpdateTicker must not error for asset %s", a) avail, err := e.GetAvailablePairs(a) require.NoError(t, err) for _, p := range avail { _, err = ticker.GetTicker(e.Name, p, a) - assert.NoErrorf(t, err, "Could not get ticker for %s %s", a, p) + assert.NoErrorf(t, err, "GetTicker should not error for %s %s", a, p) } } } @@ -1910,8 +1913,8 @@ func TestGetCurrencyTradeURL(t *testing.T) { updatePairsOnce(t, e) for _, a := range e.GetAssetTypes(false) { pairs, err := e.CurrencyPairs.GetPairs(a, false) - require.NoErrorf(t, err, "cannot get pairs for %s", a) - require.NotEmptyf(t, pairs, "no pairs for %s", a) + require.NoErrorf(t, err, "GetPairs must not error for asset %s", a) + require.NotEmptyf(t, pairs, "pairs must not be empty for asset %s", a) resp, err := e.GetCurrencyTradeURL(t.Context(), a, pairs[0]) require.NoError(t, err) assert.NotEmpty(t, resp) @@ -1992,58 +1995,6 @@ func TestGenerateSubscriptions(t *testing.T) { testsubs.EqualLists(t, exp, subs) } -func wsFixture(tb testing.TB, msg []byte, w *gws.Conn) error { - tb.Helper() - action, _ := jsonparser.GetString(msg, "action") - ch, _ := jsonparser.GetString(msg, "ch") - if action == "req" && ch == "auth" { - return w.WriteMessage(gws.TextMessage, []byte(`{"action":"req","code":200,"ch":"auth","data":{}}`)) - } - if action == "sub" { - return w.WriteMessage(gws.TextMessage, []byte(`{"action":"sub","code":200,"ch":"`+ch+`"}`)) - } - id, _ := jsonparser.GetString(msg, "id") - sub, _ := jsonparser.GetString(msg, "sub") - if id != "" && sub != "" { - return w.WriteMessage(gws.TextMessage, []byte(`{"id":"`+id+`","status":"ok","subbed":"`+sub+`"}`)) - } - return fmt.Errorf("%w: %s", errors.New("Unhandled mock websocket message"), msg) -} - -// TestSubscribe exercises live public subscriptions -func TestSubscribe(t *testing.T) { - t.Parallel() - e := new(Exchange) - require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") - subs, err := e.Features.Subscriptions.ExpandTemplates(e) - require.NoError(t, err, "ExpandTemplates must not error") - testexch.SetupWs(t, e) - err = e.Subscribe(subs) - require.NoError(t, err, "Subscribe must not error") - got := e.Websocket.GetSubscriptions() - require.Equal(t, 8, len(got), "Must get correct number of subscriptions") - for _, s := range got { - assert.Equal(t, subscription.SubscribedState, s.State()) - } -} - -// TestAuthSubscribe exercises mock subscriptions including private -func TestAuthSubscribe(t *testing.T) { - t.Parallel() - subCfg := e.Features.Subscriptions - h := testexch.MockWsInstance[Exchange](t, mockws.CurryWsMockUpgrader(t, wsFixture)) - h.Websocket.SetCanUseAuthenticatedEndpoints(true) - subs, err := subCfg.ExpandTemplates(h) - require.NoError(t, err, "ExpandTemplates must not error") - err = h.Subscribe(subs) - require.NoError(t, err, "Subscribe must not error") - got := h.Websocket.GetSubscriptions() - require.Equal(t, 11, len(got), "Must get correct number of subscriptions") - for _, s := range got { - assert.Equal(t, subscription.SubscribedState, s.State()) - } -} - func TestChannelName(t *testing.T) { assert.Equal(t, "market.BTC-USD.kline", channelName(&subscription.Subscription{Channel: subscription.CandlesChannel}, btcusdPair)) assert.Equal(t, "trade.clearing#*#1", channelName(&subscription.Subscription{Channel: subscription.MyTradesChannel}, btcusdPair)) diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index 9a34235d657..b8f04c04006 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -20,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchange/accounts" "github.com/thrasher-corp/gocryptotrader/exchange/websocket" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" @@ -74,66 +75,59 @@ var subscriptionNames = map[string]string{ subscription.MyAccountChannel: wsMyAccountChannel, } -// WsConnect initiates a new websocket connection -func (e *Exchange) WsConnect() error { - ctx := context.TODO() - if !e.Websocket.IsEnabled() || !e.IsEnabled() { - return websocket.ErrWebsocketNotEnabled +func (e *Exchange) wsConnect(ctx context.Context, conn websocket.Connection) error { + return conn.Dial(ctx, &gws.Dialer{}, http.Header{}, nil) +} + +func (e *Exchange) wsAuth(ctx context.Context, conn websocket.Connection) error { + authURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + if err != nil { + authURL = wsSpotURL + wsPrivatePath } - if err := e.Websocket.Conn.Dial(ctx, &gws.Dialer{}, http.Header{}, nil); err != nil { - return err + if conn.GetURL() != authURL { + return nil } - - e.Websocket.Wg.Add(1) - go e.wsReadMsgs(ctx, e.Websocket.Conn) - - if e.IsWebsocketAuthenticationSupported() { - if err := e.wsAuthConnect(ctx); err != nil { - e.Websocket.SetCanUseAuthenticatedEndpoints(false) - return fmt.Errorf("error authenticating websocket: %w", err) - } - e.Websocket.SetCanUseAuthenticatedEndpoints(true) - e.Websocket.Wg.Add(1) - go e.wsReadMsgs(ctx, e.Websocket.AuthConn) + if err := e.wsLogin(ctx, conn); err != nil { + e.Websocket.SetCanUseAuthenticatedEndpoints(false) + return fmt.Errorf("error authenticating websocket: %w", err) } - + e.Websocket.SetCanUseAuthenticatedEndpoints(true) return nil } -// wsReadMsgs reads and processes messages from a websocket connection -func (e *Exchange) wsReadMsgs(ctx context.Context, s websocket.Connection) { - defer e.Websocket.Wg.Done() - for { - msg := s.ReadMessage() - if msg.Raw == nil { - return - } +func (e *Exchange) generatePublicSubscriptions() (subscription.List, error) { + subs, err := e.generateSubscriptions() + if err != nil { + return nil, err + } + return subs.Public(), nil +} - if err := e.wsHandleData(ctx, msg.Raw); err != nil { - if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { - log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err) - } - } +func (e *Exchange) generatePrivateSubscriptions() (subscription.List, error) { + subs, err := e.generateSubscriptions() + if err != nil { + return nil, err } + return subs.Private(), nil } -func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, conn websocket.Connection, respRaw []byte) error { if id, err := jsonparser.GetString(respRaw, "id"); err == nil { - if e.Websocket.Match.IncomingWithData(id, respRaw) { + if conn.IncomingWithData(id, respRaw) { return nil } } if pingValue, err := jsonparser.GetInt(respRaw, "ping"); err == nil { - return e.wsHandleV1ping(ctx, int(pingValue)) + return e.wsHandleV1ping(ctx, conn, int(pingValue)) } if action, err := jsonparser.GetString(respRaw, "action"); err == nil { switch action { case "ping": - return e.wsHandleV2ping(ctx, respRaw) + return e.wsHandleV2ping(ctx, conn, respRaw) case wsSubOp, wsUnsubOp: - return e.wsHandleV2subResp(action, respRaw) + return e.wsHandleV2subResp(conn, action, respRaw) } } @@ -146,7 +140,7 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { if s == nil { return fmt.Errorf("%w: %q", subscription.ErrNotFound, ch) } - return e.wsHandleChannelMsgs(ctx, s, respRaw) + return e.wsHandleChannelMsgs(ctx, conn, s, respRaw) } return e.Websocket.DataHandler.Send(ctx, websocket.UnhandledMessageWarning{ @@ -155,33 +149,33 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { } // wsHandleV1ping handles v1 style pings, currently only used with public connections -func (e *Exchange) wsHandleV1ping(ctx context.Context, pingValue int) error { - if err := e.Websocket.Conn.SendJSONMessage(ctx, request.Unset, json.RawMessage(`{"pong":`+strconv.Itoa(pingValue)+`}`)); err != nil { +func (e *Exchange) wsHandleV1ping(ctx context.Context, conn websocket.Connection, pingValue int) error { + if err := conn.SendJSONMessage(ctx, request.Unset, json.RawMessage(`{"pong":`+strconv.Itoa(pingValue)+`}`)); err != nil { return fmt.Errorf("error sending pong response: %w", err) } return nil } // wsHandleV2ping handles v2 style pings, currently only used with private connections -func (e *Exchange) wsHandleV2ping(ctx context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleV2ping(ctx context.Context, conn websocket.Connection, respRaw []byte) error { ts, err := jsonparser.GetInt(respRaw, "data", "ts") if err != nil { return fmt.Errorf("error getting ts from auth ping: %w", err) } - if err := e.Websocket.AuthConn.SendJSONMessage(ctx, request.Unset, json.RawMessage(`{"action":"pong","data":{"ts":`+strconv.FormatInt(ts, 10)+`}}`)); err != nil { + if err := conn.SendJSONMessage(ctx, request.Unset, json.RawMessage(`{"action":"pong","data":{"ts":`+strconv.FormatInt(ts, 10)+`}}`)); err != nil { return fmt.Errorf("error sending auth pong response: %w", err) } return nil } -func (e *Exchange) wsHandleV2subResp(action string, respRaw []byte) error { +func (e *Exchange) wsHandleV2subResp(conn websocket.Connection, action string, respRaw []byte) error { if ch, err := jsonparser.GetString(respRaw, "ch"); err == nil { - return e.Websocket.Match.RequireMatchWithData(action+":"+ch, respRaw) + return conn.RequireMatchWithData(action+":"+ch, respRaw) } return nil } -func (e *Exchange) wsHandleChannelMsgs(ctx context.Context, s *subscription.Subscription, respRaw []byte) error { +func (e *Exchange) wsHandleChannelMsgs(ctx context.Context, _ websocket.Connection, s *subscription.Subscription, respRaw []byte) error { switch s.Channel { case subscription.TickerChannel: return e.wsHandleTickerMsg(ctx, s, respRaw) @@ -497,32 +491,23 @@ func (e *Exchange) GetSubscriptionTemplate(_ *subscription.Subscription) (*templ }).Parse(subTplText) } -// Subscribe sends a websocket message to receive data from the channel -func (e *Exchange) Subscribe(subs subscription.List) error { - ctx := context.TODO() - subs, errs := subs.ExpandTemplates(e) - return common.AppendError(errs, e.ParallelChanOp(ctx, subs, func(ctx context.Context, l subscription.List) error { return e.manageSubs(ctx, wsSubOp, l) }, 1)) +func (e *Exchange) subscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return common.AppendError(nil, e.ParallelChanOp(ctx, subs, func(ctx context.Context, l subscription.List) error { return e.manageSubs(ctx, conn, wsSubOp, l) }, 1)) } -// Unsubscribe sends a websocket message to stop receiving data from the channel -func (e *Exchange) Unsubscribe(subs subscription.List) error { - ctx := context.TODO() - subs, errs := subs.ExpandTemplates(e) - return common.AppendError(errs, e.ParallelChanOp(ctx, subs, func(ctx context.Context, l subscription.List) error { return e.manageSubs(ctx, wsUnsubOp, l) }, 1)) +func (e *Exchange) unsubscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return common.AppendError(nil, e.ParallelChanOp(ctx, subs, func(ctx context.Context, l subscription.List) error { return e.manageSubs(ctx, conn, wsUnsubOp, l) }, 1)) } -func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription.List) error { +func (e *Exchange) manageSubs(ctx context.Context, conn websocket.Connection, op string, subs subscription.List) error { if len(subs) != 1 { return subscription.ErrBatchingNotSupported } s := subs[0] - var c websocket.Connection var req any if s.Authenticated { - c = e.Websocket.AuthConn req = wsReq{Action: op, Channel: s.QualifiedChannel} } else { - c = e.Websocket.Conn if op == wsSubOp { // Set the id to the channel so that V1 errors can make it back to us req = wsSubReq{ID: wsSubOp + ":" + s.QualifiedChannel, Sub: s.QualifiedChannel} @@ -532,17 +517,17 @@ func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription. } if op == wsSubOp { s.SetKey(s.QualifiedChannel) - if err := e.Websocket.AddSubscriptions(c, s); err != nil { + if err := e.Websocket.AddSubscriptions(conn, s); err != nil { return fmt.Errorf("%w: %s; error: %w", websocket.ErrSubscriptionFailure, s, err) } } - respRaw, err := c.SendMessageReturnResponse(ctx, request.Unset, wsSubOp+":"+s.QualifiedChannel, req) + respRaw, err := conn.SendMessageReturnResponse(ctx, request.Unset, wsSubOp+":"+s.QualifiedChannel, req) if err == nil { err = getErrResp(respRaw) } if err != nil { if op == wsSubOp { - _ = e.Websocket.RemoveSubscriptions(c, s) + _ = e.Websocket.RemoveSubscriptions(conn, s) } return fmt.Errorf("%s: %w", s, err) } @@ -552,7 +537,7 @@ func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription. log.Debugf(log.ExchangeSys, "%s Subscribed to %s", e.Name, s) } } else { - err = e.Websocket.RemoveSubscriptions(c, s) + err = e.Websocket.RemoveSubscriptions(conn, s) } return err } @@ -567,17 +552,7 @@ func (e *Exchange) wsGenerateSignature(creds *accounts.Credentials, timestamp st return crypto.GetHMAC(crypto.HashSHA256, []byte(payload), []byte(creds.Secret)) } -func (e *Exchange) wsAuthConnect(ctx context.Context) error { - if err := e.Websocket.AuthConn.Dial(ctx, &gws.Dialer{}, http.Header{}, nil); err != nil { - return fmt.Errorf("authenticated dial failed: %w", err) - } - if err := e.wsLogin(ctx); err != nil { - return fmt.Errorf("authentication failed: %w", err) - } - return nil -} - -func (e *Exchange) wsLogin(ctx context.Context) error { +func (e *Exchange) wsLogin(ctx context.Context, conn websocket.Connection) error { creds, err := e.GetCredentials(ctx) if err != nil { return err @@ -600,11 +575,10 @@ func (e *Exchange) wsLogin(ctx context.Context) error { Timestamp: ts, }, } - c := e.Websocket.AuthConn - if err := c.SendJSONMessage(ctx, request.Unset, req); err != nil { + if err := conn.SendJSONMessage(ctx, request.Unset, req); err != nil { return err } - resp := c.ReadMessage() + resp := conn.ReadMessage() if resp.Raw == nil { return &gws.CloseError{Code: gws.CloseAbnormalClosure} } diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 6df864852b6..18d473af9c6 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -160,10 +160,11 @@ func (e *Exchange) SetDefaults() { } e.API.Endpoints = e.NewEndpoints() err = e.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ - exchange.RestSpot: huobiAPIURL, - exchange.RestFutures: huobiFuturesURL, - exchange.RestCoinMargined: huobiFuturesURL, - exchange.WebsocketSpot: wsSpotURL + wsPublicPath, + exchange.RestSpot: huobiAPIURL, + exchange.RestFutures: huobiFuturesURL, + exchange.RestCoinMargined: huobiFuturesURL, + exchange.WebsocketSpot: wsSpotURL + wsPublicPath, + exchange.WebsocketSpotSupplementary: wsSpotURL + wsPrivatePath, }) if err != nil { log.Errorln(log.ExchangeSys, err) @@ -206,34 +207,48 @@ func (e *Exchange) Setup(exch *config.Exchange) error { } err = e.Websocket.Setup(&websocket.ManagerSetup{ - ExchangeConfig: exch, - DefaultURL: wsSpotURL + wsPublicPath, - RunningURL: wsRunningURL, - Connector: e.WsConnect, - Subscriber: e.Subscribe, - Unsubscriber: e.Unsubscribe, - GenerateSubscriptions: e.generateSubscriptions, - Features: &e.Features.Supports.WebsocketCapabilities, + ExchangeConfig: exch, + UseMultiConnectionManagement: true, + Features: &e.Features.Supports.WebsocketCapabilities, }) if err != nil { return err } err = e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + URL: wsRunningURL, + Connector: e.wsConnect, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generatePublicSubscriptions, + Handler: e.wsHandleData, + RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + MessageFilter: wsRunningURL, }) if err != nil { return err } + wsRunningAuthURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + if err != nil { + return err + } + return e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - URL: wsSpotURL + wsPrivatePath, - Authenticated: true, + URL: wsRunningAuthURL, + Connector: e.wsConnect, + Authenticate: e.wsAuth, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generatePrivateSubscriptions, + SubscriptionsNotRequired: true, + Handler: e.wsHandleData, + RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + MessageFilter: wsRunningAuthURL, }) } @@ -1707,11 +1722,6 @@ func setOrderSideStatusAndType(orderState, requestType string, orderDetail *orde } } -// AuthenticateWebsocket sends an authentication message to the websocket -func (e *Exchange) AuthenticateWebsocket(ctx context.Context) error { - return e.wsLogin(ctx) -} - // ValidateAPICredentials validates current credentials used for wrapper functionality func (e *Exchange) ValidateAPICredentials(ctx context.Context, assetType asset.Item) error { _, err := e.UpdateAccountBalances(ctx, assetType) diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index 79b2baba9f1..be8f5a4118f 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1,9 +1,12 @@ package kraken import ( + "context" "errors" "log" + "maps" "net/http" + "net/http/httptest" "os" "strings" "sync" @@ -17,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/core" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" + "github.com/thrasher-corp/gocryptotrader/exchange/websocket" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/fundingrate" @@ -877,258 +881,6 @@ func TestWithdrawCancel(t *testing.T) { // ---------------------------- Websocket tests ----------------------------------------- -// TestWsSubscribe tests unauthenticated websocket subscriptions -// Specifically looking to ensure multiple errors are collected and returned and ws.Subscriptions Added/Removed in cases of: -// single pass, single fail, mixed fail, multiple pass, all fail -// No objection to this becoming a fixture test, so long as it integrates through Un/Subscribe roundtrip -func TestWsSubscribe(t *testing.T) { - e := new(Exchange) - require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - testexch.SetupWs(t, e) - - for _, enabled := range []bool{false, true} { - require.NoError(t, e.SetPairs(currency.Pairs{ - spotTestPair, - currency.NewPairWithDelimiter("ETH", "USD", "/"), - currency.NewPairWithDelimiter("LTC", "ETH", "/"), - currency.NewPairWithDelimiter("ETH", "XBT", "/"), - // Enable pairs that won't error locally, so we get upstream errors to test error combinations - currency.NewPairWithDelimiter("DWARF", "HOBBIT", "/"), - currency.NewPairWithDelimiter("DWARF", "GOBLIN", "/"), - currency.NewPairWithDelimiter("DWARF", "ELF", "/"), - }, asset.Spot, enabled), "SetPairs must not error") - } - - err := e.Subscribe(subscription.List{{Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{spotTestPair}}}) - require.NoError(t, err, "Simple subscription must not error") - subs := e.Websocket.GetSubscriptions() - require.Len(t, subs, 1, "Should add 1 Subscription") - assert.Equal(t, subscription.SubscribedState, subs[0].State(), "Subscription should be subscribed state") - - err = e.Subscribe(subscription.List{{Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{spotTestPair}}}) - assert.ErrorIs(t, err, subscription.ErrDuplicate, "Resubscribing to the same channel should error with SubscribedAlready") - subs = e.Websocket.GetSubscriptions() - require.Len(t, subs, 1, "Should not add a subscription on error") - assert.Equal(t, subscription.SubscribedState, subs[0].State(), "Existing subscription state should not change") - - err = e.Subscribe(subscription.List{{Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("DWARF", "HOBBIT", "/")}}}) - assert.ErrorContains(t, err, "Currency pair not supported; Channel: ticker Pairs: DWARF/HOBBIT", "Subscribing to an invalid pair should error correctly") - require.Len(t, e.Websocket.GetSubscriptions(), 1, "Should not add a subscription on error") - - // Mix success and failure - err = e.Subscribe(subscription.List{ - {Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("ETH", "USD", "/")}}, - {Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("DWARF", "HOBBIT", "/")}}, - {Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("DWARF", "ELF", "/")}}, - }) - assert.ErrorContains(t, err, "Currency pair not supported; Channel: ticker Pairs:", "Subscribing to an invalid pair should error correctly") - assert.ErrorContains(t, err, "DWARF/HOBBIT", "Subscribing to an invalid pair should error correctly") - assert.ErrorContains(t, err, "DWARF/ELF", "Subscribing to an invalid pair should error correctly") - require.Len(t, e.Websocket.GetSubscriptions(), 2, "Should have 2 subscriptions after mixed success/failures") - - // Just failures - err = e.Subscribe(subscription.List{ - {Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("DWARF", "HOBBIT", "/")}}, - {Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("DWARF", "GOBLIN", "/")}}, - }) - assert.ErrorContains(t, err, "Currency pair not supported; Channel: ticker Pairs:", "Subscribing to an invalid pair should error correctly") - assert.ErrorContains(t, err, "DWARF/HOBBIT", "Subscribing to an invalid pair should error correctly") - assert.ErrorContains(t, err, "DWARF/GOBLIN", "Subscribing to an invalid pair should error correctly") - require.Len(t, e.Websocket.GetSubscriptions(), 2, "Should have 2 subscriptions after mixed success/failures") - - // Just success - err = e.Subscribe(subscription.List{ - {Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("ETH", "XBT", "/")}}, - {Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("LTC", "ETH", "/")}}, - }) - assert.NoError(t, err, "Multiple successful subscriptions should not error") - - subs = e.Websocket.GetSubscriptions() - assert.Len(t, subs, 4, "Should have correct number of subscriptions") - - err = e.Unsubscribe(subs[:1]) - assert.NoError(t, err, "Simple Unsubscribe should succeed") - assert.Len(t, e.Websocket.GetSubscriptions(), 3, "Should have removed 1 channel") - - err = e.Unsubscribe(subscription.List{{Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("DWARF", "WIZARD", "/")}, Key: 1337}}) - assert.ErrorIs(t, err, subscription.ErrNotFound, "Simple failing Unsubscribe should error NotFound") - assert.ErrorContains(t, err, "DWARF/WIZARD", "Unsubscribing from an invalid pair should error correctly") - assert.Len(t, e.Websocket.GetSubscriptions(), 3, "Should not have removed any channels") - - err = e.Unsubscribe(subscription.List{ - subs[1], - {Asset: asset.Spot, Channel: subscription.TickerChannel, Pairs: currency.Pairs{currency.NewPairWithDelimiter("DWARF", "EAGLE", "/")}, Key: 1338}, - }) - assert.ErrorIs(t, err, subscription.ErrNotFound, "Mixed failing Unsubscribe should error NotFound") - assert.ErrorContains(t, err, "Channel: ticker Pairs: DWARF/EAGLE", "Unsubscribing from an invalid pair should error correctly") - - subs = e.Websocket.GetSubscriptions() - assert.Len(t, subs, 2, "Should have removed only 1 more channel") - - err = e.Unsubscribe(subs) - assert.NoError(t, err, "Unsubscribe multiple passing subscriptions should not error") - assert.Empty(t, e.Websocket.GetSubscriptions(), "Should have successfully removed all channels") - - for _, c := range []string{"ohlc", "ohlc-5"} { - err = e.Subscribe(subscription.List{{ - Asset: asset.Spot, - Channel: c, - Pairs: currency.Pairs{spotTestPair}, - }}) - assert.ErrorIs(t, err, subscription.ErrUseConstChannelName, "Must error when trying to use a private channel name") - assert.ErrorContains(t, err, c+" => subscription.CandlesChannel", "Must error when trying to use a private channel name") - } -} - -// TestWsResubscribe tests websocket resubscription -func TestWsResubscribe(t *testing.T) { - e := new(Exchange) - require.NoError(t, testexch.Setup(e), "TestInstance must not error") - testexch.SetupWs(t, e) - - err := e.Subscribe(subscription.List{{Asset: asset.Spot, Channel: subscription.OrderbookChannel, Levels: 1000}}) - require.NoError(t, err, "Subscribe must not error") - subs := e.Websocket.GetSubscriptions() - require.Len(t, subs, 1, "Should add 1 Subscription") - require.Equal(t, subscription.SubscribedState, subs[0].State(), "Subscription must be in a subscribed state") - - require.Eventually(t, func() bool { - b, e2 := e.Websocket.Orderbook.GetOrderbook(spotTestPair, asset.Spot) - if e2 == nil { - return !b.LastUpdated.IsZero() - } - return false - }, time.Second*4, time.Millisecond*10, "orderbook must start streaming") - - // Set the state to Unsub so we definitely know Resub worked - err = subs[0].SetState(subscription.UnsubscribingState) - require.NoError(t, err) - - err = e.Websocket.ResubscribeToChannel(t.Context(), e.Websocket.Conn, subs[0]) - require.NoError(t, err, "Resubscribe must not error") - require.Equal(t, subscription.SubscribedState, subs[0].State(), "subscription must be subscribed again") -} - -// TestWsOrderbookSub tests orderbook subscriptions for MaxDepth params -func TestWsOrderbookSub(t *testing.T) { - t.Parallel() - - e := new(Exchange) - require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - testexch.SetupWs(t, e) - - err := e.Subscribe(subscription.List{{ - Asset: asset.Spot, - Channel: subscription.OrderbookChannel, - Pairs: currency.Pairs{spotTestPair}, - Levels: 25, - }}) - require.NoError(t, err, "Simple subscription must not error") - - subs := e.Websocket.GetSubscriptions() - require.Equal(t, 1, len(subs), "Must have 1 subscription channel") - - err = e.Unsubscribe(subs) - assert.NoError(t, err, "Unsubscribe should not error") - assert.Empty(t, e.Websocket.GetSubscriptions(), "Should have successfully removed all channels") - - err = e.Subscribe(subscription.List{{ - Asset: asset.Spot, - Channel: subscription.OrderbookChannel, - Pairs: currency.Pairs{spotTestPair}, - Levels: 42, - }}) - assert.ErrorContains(t, err, "Subscription depth not supported", "Bad subscription should error about depth") -} - -// TestWsCandlesSub tests candles subscription for Timeframe params -func TestWsCandlesSub(t *testing.T) { - t.Parallel() - - e := new(Exchange) - require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - testexch.SetupWs(t, e) - - err := e.Subscribe(subscription.List{{ - Asset: asset.Spot, - Channel: subscription.CandlesChannel, - Pairs: currency.Pairs{spotTestPair}, - Interval: kline.OneHour, - }}) - require.NoError(t, err, "Simple subscription must not error") - - subs := e.Websocket.GetSubscriptions() - require.Equal(t, 1, len(subs), "Should add 1 Subscription") - - err = e.Unsubscribe(subs) - assert.NoError(t, err, "Unsubscribe should not error") - assert.Empty(t, e.Websocket.GetSubscriptions(), "Should have successfully removed all channels") - - err = e.Subscribe(subscription.List{{ - Asset: asset.Spot, - Channel: subscription.CandlesChannel, - Pairs: currency.Pairs{spotTestPair}, - Interval: kline.Interval(time.Minute * time.Duration(127)), - }}) - assert.ErrorContains(t, err, "Subscription ohlc interval not supported", "Bad subscription should error about interval") -} - -func TestWsProcessCandleIntervalMapping(t *testing.T) { - t.Parallel() - ex := new(Exchange) - require.NoError(t, testexch.Setup(ex), "Setup Instance must not error") - - err := ex.wsProcessCandle(t.Context(), - "ohlc-5", - json.RawMessage(`[1542057314,1542057360,3586.7,3586.7,3586.6,3586.6,3586.68,0.03373,2]`), - currency.NewPairWithDelimiter("XBT", "USD", "/")) - require.NoError(t, err) - - select { - case msg := <-ex.Websocket.DataHandler.C: - got, ok := msg.Data.(kline.Item) - require.True(t, ok, "expected kline item") - assert.Equal(t, kline.Item{ - Asset: asset.Spot, - Pair: currency.NewPairWithDelimiter("XBT", "USD", "/"), - Exchange: ex.Name, - Interval: kline.FiveMin, - Candles: []kline.Candle{{ - Time: time.Unix(1542057314, 0), - Open: 3586.7, - High: 3586.7, - Low: 3586.6, - Close: 3586.6, - Volume: 0.03373, - QuoteVolume: 120.97871640000001, - }}, - }, got) - default: - require.Fail(t, "expected websocket candle payload") - } -} - -// TestWsOwnTradesSub tests the authenticated WS subscription channel for trades -func TestWsOwnTradesSub(t *testing.T) { - t.Parallel() - - sharedtestvalues.SkipTestIfCredentialsUnset(t, e) - - e := new(Exchange) - require.NoError(t, testexch.Setup(e), "Setup Instance must not error") - testexch.SetupWs(t, e) - - err := e.Subscribe(subscription.List{{Channel: subscription.MyTradesChannel, Authenticated: true}}) - assert.NoError(t, err, "Subsrcibing to ownTrades should not error") - - subs := e.Websocket.GetSubscriptions() - assert.Len(t, subs, 1, "Should add 1 Subscription") - - err = e.Unsubscribe(subs) - assert.NoError(t, err, "Unsubscribing an auth channel should not error") - assert.Empty(t, e.Websocket.GetSubscriptions(), "Should have successfully removed channel") -} - // TestGenerateSubscriptions tests the subscriptions generated from configuration func TestGenerateSubscriptions(t *testing.T) { t.Parallel() @@ -1174,11 +926,81 @@ func TestGetWSToken(t *testing.T) { assert.NotEmpty(t, resp, "Token should not be empty") } +func TestSubscribeForConnection(t *testing.T) { + t.Parallel() + + k := mockWsInstance(t, curryWsMockUpgrader(t, mockWsServer)) + + wsRunningURL, err := k.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + require.NoError(t, err, "GetURL must not error") + conn, err := k.Websocket.GetConnection(wsRunningURL) + require.NoError(t, err, "GetConnection must not error") + + subs := subscription.List{ + { + Asset: asset.Spot, + Channel: subscription.OrderbookChannel, + QualifiedChannel: channelName(&subscription.Subscription{Channel: subscription.OrderbookChannel}), + Pairs: currency.Pairs{spotTestPair}, + Levels: 1000, + }, + { + Asset: asset.Spot, + Channel: subscription.OrderbookChannel, + QualifiedChannel: channelName(&subscription.Subscription{Channel: subscription.OrderbookChannel}), + Pairs: currency.Pairs{currency.NewPair(currency.ETH, currency.USD)}, + Levels: 1000, + }, + } + + require.NoError(t, k.subscribeForConnection(t.Context(), conn, subs), "subscribeForConnection must not error") + + for i := range subs { + s := subs[i] + require.Eventually(t, func() bool { + got := k.Websocket.GetSubscription(s) + return got != nil && got.State() == subscription.SubscribedState + }, time.Second, 10*time.Millisecond, "subscription must transition to subscribed state") + } +} + +func TestSubscribeForConnectionResubscribeAndUnsubscribe(t *testing.T) { + t.Parallel() + + k := mockWsInstance(t, curryWsMockUpgrader(t, mockWsServer)) + + wsRunningURL, err := k.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + require.NoError(t, err, "GetURL must not error") + conn, err := k.Websocket.GetConnection(wsRunningURL) + require.NoError(t, err, "GetConnection must not error") + + sub := &subscription.Subscription{ + Asset: asset.Spot, + Channel: subscription.TickerChannel, + QualifiedChannel: channelName(&subscription.Subscription{Channel: subscription.TickerChannel}), + Pairs: currency.Pairs{spotTestPair}, + } + + require.NoError(t, k.Websocket.AddSubscriptions(conn, sub), "AddSubscriptions must not error") + require.NoError(t, sub.SetState(subscription.ResubscribingState), "SetState must not error") + + require.NoError(t, k.subscribeForConnection(t.Context(), conn, subscription.List{sub}), "subscribeForConnection must not error") + require.Eventually(t, func() bool { + got := k.Websocket.GetSubscription(sub) + return got != nil && got.State() == subscription.SubscribedState + }, time.Second, 10*time.Millisecond, "resubscribing subscription must transition to subscribed state") + + require.NoError(t, k.unsubscribeForConnection(t.Context(), conn, subscription.List{sub}), "unsubscribeForConnection must not error") + require.Eventually(t, func() bool { + return k.Websocket.GetSubscription(sub) == nil + }, time.Second, 10*time.Millisecond, "subscription must be removed after unsubscribe") +} + // TestWsAddOrder exercises roundtrip of wsAddOrder; See also: mockWsAddOrder func TestWsAddOrder(t *testing.T) { t.Parallel() - k := testexch.MockWsInstance[Exchange](t, curryWsMockUpgrader(t, mockWsServer)) + k := mockWsInstance(t, curryWsMockUpgrader(t, mockWsServer)) require.True(t, k.IsWebsocketAuthenticationSupported(), "WS must be authenticated") id, err := k.wsAddOrder(t.Context(), &WsAddOrderRequest{ OrderType: order.Limit.Lower(), @@ -1194,7 +1016,7 @@ func TestWsAddOrder(t *testing.T) { func TestWsCancelOrders(t *testing.T) { t.Parallel() - k := testexch.MockWsInstance[Exchange](t, curryWsMockUpgrader(t, mockWsServer)) + k := mockWsInstance(t, curryWsMockUpgrader(t, mockWsServer)) require.True(t, k.IsWebsocketAuthenticationSupported(), "WS must be authenticated") err := k.wsCancelOrders(t.Context(), []string{"RABBIT", "BATFISH", "SQUIRREL", "CATFISH", "MOUSE"}) @@ -1219,7 +1041,7 @@ func TestWsHandleData(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup Instance must not error") for _, l := range []int{10, 100} { - err := e.Websocket.AddSuccessfulSubscriptions(e.Websocket.Conn, &subscription.Subscription{ + err := e.Websocket.AddSuccessfulSubscriptions(nil, &subscription.Subscription{ Channel: subscription.OrderbookChannel, Pairs: currency.Pairs{spotTestPair}, Asset: asset.Spot, @@ -1227,7 +1049,8 @@ func TestWsHandleData(t *testing.T) { }) require.NoError(t, err, "AddSuccessfulSubscriptions must not error") } - testexch.FixtureToDataHandler(t, "testdata/wsHandleData.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsHandleData.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) } func TestWSProcessTrades(t *testing.T) { @@ -1235,9 +1058,10 @@ func TestWSProcessTrades(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{spotTestPair}, Channel: subscription.AllTradesChannel, Key: 18788}) + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{spotTestPair}, Channel: subscription.AllTradesChannel, Key: 18788}) require.NoError(t, err, "AddSubscriptions must not error") - testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() invalid := []any{"trades", []any{[]any{"95873.80000", "0.00051182", "1708731380.3791859"}}} @@ -1273,7 +1097,8 @@ func TestWsOpenOrders(t *testing.T) { e := new(Exchange) require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") testexch.UpdatePairsOnce(t, e) - testexch.FixtureToDataHandler(t, "testdata/wsOpenTrades.json", e.wsHandleData) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsOpenTrades.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() assert.Len(t, e.Websocket.DataHandler.C, 7, "Should see 7 orders") for resp := range e.Websocket.DataHandler.C { @@ -1487,7 +1312,7 @@ func TestWsOrderbookMax10Depth(t *testing.T) { currency.NewPairWithDelimiter("GST", "EUR", "/"), } for _, p := range pairs { - err := e.Websocket.AddSuccessfulSubscriptions(e.Websocket.Conn, &subscription.Subscription{ + err := e.Websocket.AddSuccessfulSubscriptions(nil, &subscription.Subscription{ Channel: subscription.OrderbookChannel, Pairs: currency.Pairs{p}, Asset: asset.Spot, @@ -1495,14 +1320,15 @@ func TestWsOrderbookMax10Depth(t *testing.T) { }) require.NoError(t, err, "AddSuccessfulSubscriptions must not error") } + conn := testexch.GetMockConn(t, e, "") for x := range websocketXDGUSDOrderbookUpdates { - err := e.wsHandleData(t.Context(), []byte(websocketXDGUSDOrderbookUpdates[x])) + err := e.wsHandleData(t.Context(), conn, []byte(websocketXDGUSDOrderbookUpdates[x])) require.NoError(t, err, "wsHandleData must not error") } for x := range websocketLUNAEUROrderbookUpdates { - err := e.wsHandleData(t.Context(), []byte(websocketLUNAEUROrderbookUpdates[x])) + err := e.wsHandleData(t.Context(), conn, []byte(websocketLUNAEUROrderbookUpdates[x])) // TODO: Known issue with LUNA pairs and big number float precision // storage and checksum calc. Might need to store raw strings as fields // in the orderbook.Level struct. @@ -1514,7 +1340,7 @@ func TestWsOrderbookMax10Depth(t *testing.T) { // This has less than 10 bids and still needs a checksum calc. for x := range websocketGSTEUROrderbookUpdates { - err := e.wsHandleData(t.Context(), []byte(websocketGSTEUROrderbookUpdates[x])) + err := e.wsHandleData(t.Context(), conn, []byte(websocketGSTEUROrderbookUpdates[x])) require.NoError(t, err, "wsHandleData must not error") } } @@ -1626,6 +1452,36 @@ func curryWsMockUpgrader(tb testing.TB, h mockws.WsMockFunc) http.HandlerFunc { } } +func mockWsInstance(tb testing.TB, h http.HandlerFunc) *Exchange { + tb.Helper() + + e := new(Exchange) + require.NoError(tb, testexch.Setup(e), "Test exchange Setup must not error") + + s := httptest.NewServer(h) + tb.Cleanup(s.Close) + wsURL := "ws" + strings.TrimPrefix(s.URL, "http") + + b := e.GetBase() + cfg := *b.Config + cfg.API.Endpoints = make(map[string]string, len(b.Config.API.Endpoints)) + maps.Copy(cfg.API.Endpoints, b.Config.API.Endpoints) + cfg.API.AuthenticatedWebsocketSupport = true + cfg.API.Endpoints["RestSpotURL"] = s.URL + cfg.API.Endpoints["WebsocketSpotURL"] = wsURL + "/public" + cfg.API.Endpoints["WebsocketSpotSupplementaryURL"] = wsURL + "/private" + e.Websocket = websocket.NewManager() + require.NoError(tb, e.Setup(&cfg), "Setup must not error") + + b = e.GetBase() + b.SkipAuthCheck = true + b.API.AuthenticatedWebsocketSupport = true + b.Features.Subscriptions = subscription.List{} + b.Websocket.GenerateSubs = func() (subscription.List, error) { return subscription.List{}, nil } + require.NoError(tb, b.Websocket.Connect(context.TODO()), "Connect must not error") + return e +} + func TestGetCurrencyTradeURL(t *testing.T) { t.Parallel() testexch.UpdatePairsOnce(t, e) @@ -1634,7 +1490,7 @@ func TestGetCurrencyTradeURL(t *testing.T) { if len(pairs) == 0 { continue } - require.NoErrorf(t, err, "cannot get pairs for %s", a) + require.NoErrorf(t, err, "GetPairs must not error for asset %s", a) resp, err := e.GetCurrencyTradeURL(t.Context(), a, pairs[0]) if a != asset.Spot && a != asset.Futures { assert.ErrorIs(t, err, asset.ErrNotSupported) diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 6e6ca6683a3..db950e826c3 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -18,6 +18,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchange/websocket" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" @@ -92,62 +93,41 @@ var defaultSubscriptions = subscription.List{ {Enabled: true, Channel: subscription.MyTradesChannel, Authenticated: true}, } -// WsConnect initiates a websocket connection -func (e *Exchange) WsConnect() error { - ctx := context.TODO() - if !e.Websocket.IsEnabled() || !e.IsEnabled() { - return websocket.ErrWebsocketNotEnabled +func (e *Exchange) wsConnect(ctx context.Context, conn websocket.Connection) error { + if err := conn.Dial(ctx, &gws.Dialer{}, http.Header{}, nil); err != nil { + return err } + e.startWsPingHandler(conn) + return nil +} - var dialer gws.Dialer - err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}, nil) +func (e *Exchange) wsAuthenticate(ctx context.Context, _ websocket.Connection) error { + authToken, err := e.GetWebsocketToken(ctx) if err != nil { return err } + e.setWebsocketAuthToken(authToken) + e.Websocket.SetCanUseAuthenticatedEndpoints(true) + return nil +} - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.Conn) - - if e.IsWebsocketAuthenticationSupported() { - if authToken, err := e.GetWebsocketToken(ctx); err != nil { - e.Websocket.SetCanUseAuthenticatedEndpoints(false) - log.Errorf(log.ExchangeSys, "%s - authentication failed: %v\n", e.Name, err) - } else { - if err := e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{}, nil); err != nil { - e.Websocket.SetCanUseAuthenticatedEndpoints(false) - log.Errorf(log.ExchangeSys, "%s - failed to connect to authenticated endpoint: %v\n", e.Name, err) - } else { - e.setWebsocketAuthToken(authToken) - e.Websocket.SetCanUseAuthenticatedEndpoints(true) - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.AuthConn) - e.startWsPingHandler(e.Websocket.AuthConn) - } - } +func (e *Exchange) generatePublicSubscriptions() (subscription.List, error) { + subs, err := e.generateSubscriptions() + if err != nil { + return nil, err } - - e.startWsPingHandler(e.Websocket.Conn) - - return nil + return subs.Public(), nil } -// wsReadData funnels both auth and public ws data into one manageable place -func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) { - defer e.Websocket.Wg.Done() - for { - resp := ws.ReadMessage() - if resp.Raw == nil { - return - } - if err := e.wsHandleData(ctx, resp.Raw); err != nil { - if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil { - log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, ws.GetURL(), errSend, err) - } - } +func (e *Exchange) generatePrivateSubscriptions() (subscription.List, error) { + subs, err := e.generateSubscriptions() + if err != nil { + return nil, err } + return subs.Private(), nil } -func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { +func (e *Exchange) wsHandleData(ctx context.Context, conn websocket.Connection, respRaw []byte) error { if strings.HasPrefix(string(respRaw), "[") { var msg []json.RawMessage if err := json.Unmarshal(respRaw, &msg); err != nil { @@ -173,7 +153,7 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { pair = p } - return e.wsReadDataResponse(ctx, chanName, pair, msg) + return e.wsReadDataResponse(ctx, conn, chanName, pair, msg) } event, err := jsonparser.GetString(respRaw, "event") @@ -182,11 +162,11 @@ func (e *Exchange) wsHandleData(ctx context.Context, respRaw []byte) error { } if event == krakenWsSubscriptionStatus { // Must happen before IncomingWithData to avoid race - e.wsProcessSubStatus(respRaw) + e.wsProcessSubStatus(conn, respRaw) } reqID, err := jsonparser.GetInt(respRaw, "reqid") - if err == nil && reqID != 0 && e.Websocket.Match.IncomingWithData(reqID, respRaw) { + if err == nil && reqID != 0 && conn.IncomingWithData(reqID, respRaw) { return nil } @@ -219,7 +199,7 @@ func (e *Exchange) startWsPingHandler(conn websocket.Connection) { } // wsReadDataResponse classifies the WS response and sends to appropriate handler -func (e *Exchange) wsReadDataResponse(ctx context.Context, c string, pair currency.Pair, response []json.RawMessage) error { +func (e *Exchange) wsReadDataResponse(ctx context.Context, conn websocket.Connection, c string, pair currency.Pair, response []json.RawMessage) error { switch c { case krakenWsTicker: return e.wsProcessTickers(ctx, response[1], pair) @@ -238,7 +218,7 @@ func (e *Exchange) wsReadDataResponse(ctx context.Context, c string, pair curren case krakenWsOHLC: return e.wsProcessCandle(ctx, c, response[1], pair) case krakenWsOrderbook: - return e.wsProcessOrderBook(ctx, c, response, pair) + return e.wsProcessOrderBook(ctx, conn, c, response, pair) default: return fmt.Errorf("received unidentified data for subscription %s: %+v", c, response) } @@ -454,7 +434,7 @@ func hasKey(raw json.RawMessage, key string) bool { } // wsProcessOrderBook handles both partial and full orderbook updates -func (e *Exchange) wsProcessOrderBook(ctx context.Context, c string, response []json.RawMessage, pair currency.Pair) error { +func (e *Exchange) wsProcessOrderBook(ctx context.Context, conn websocket.Connection, c string, response []json.RawMessage, pair currency.Pair) error { key := &subscription.Subscription{ Channel: c, Asset: asset.Spot, @@ -490,7 +470,7 @@ func (e *Exchange) wsProcessOrderBook(ctx context.Context, c string, response [] if errors.Is(err, errInvalidChecksum) { log.Debugf(log.Global, "%s Resubscribing to invalid %s orderbook", e.Name, pair) go func() { - if e2 := e.Websocket.ResubscribeToChannel(ctx, e.Websocket.Conn, s); e2 != nil && !errors.Is(e2, subscription.ErrInStateAlready) { + if e2 := e.Websocket.ResubscribeToChannel(ctx, conn, s); e2 != nil && !errors.Is(e2, subscription.ErrInStateAlready) { log.Errorf(log.ExchangeSys, "%s resubscription failure for %v: %v", e.Name, pair, e2) } }() @@ -678,75 +658,53 @@ func (e *Exchange) generateSubscriptions() (subscription.List, error) { return e.Features.Subscriptions.ExpandTemplates(e) } -// Subscribe adds a channel subscription to the websocket -func (e *Exchange) Subscribe(in subscription.List) error { - ctx := context.TODO() - in, errs := in.ExpandTemplates(e) - - // Collect valid new subs and add to websocket in Subscribing state - subs := subscription.List{} - for _, s := range in { - if s.State() != subscription.ResubscribingState { - if err := e.Websocket.AddSubscriptions(e.Websocket.Conn, s); err != nil { - errs = common.AppendError(errs, fmt.Errorf("%w; Channel: %s Pairs: %s", err, s.Channel, s.Pairs.Join())) - continue - } - } - subs = append(subs, s) - } - - // Merge subs by grouping pairs for request; We make a single request to subscribe to N+ pairs, but get N+ responses back - groupedSubs := subs.GroupPairs() - - errs = common.AppendError(errs, - e.ParallelChanOp(ctx, groupedSubs, func(ctx context.Context, s subscription.List) error { return e.manageSubs(ctx, krakenWsSubscribe, s) }, 1), - ) +func (e *Exchange) subscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + var errs error + // Keep per-pair keys in the store so inbound status/data messages can match + // (Kraken emits `subscriptionStatus` and updates per pair, even when requests are grouped). for _, s := range subs { - if s.State() != subscription.SubscribedState { - _ = s.SetState(subscription.InactiveState) - if err := e.Websocket.RemoveSubscriptions(e.Websocket.Conn, s); err != nil { - errs = common.AppendError(errs, fmt.Errorf("error removing failed subscription: %w; Channel: %s Pairs: %s", err, s.Channel, s.Pairs.Join())) - } + if s.State() == subscription.ResubscribingState { + continue + } + if err := e.Websocket.AddSubscriptions(conn, s); err != nil { + errs = common.AppendError(errs, fmt.Errorf("%w; Channel: %s Pairs: %s", err, s.Channel, s.Pairs.Join())) } } + errs = common.AppendError(errs, e.ParallelChanOp(ctx, subs.GroupPairs(), func(ctx context.Context, s subscription.List) error { + return e.manageSubs(ctx, krakenWsSubscribe, s, conn) + }, 1)) + + errs = common.AppendError(errs, e.cleanupUnsubscribedSubs(conn, subs)) return errs } -// Unsubscribe removes a channel subscriptions from the websocket -func (e *Exchange) Unsubscribe(keys subscription.List) error { - ctx := context.TODO() +func (e *Exchange) cleanupUnsubscribedSubs(conn websocket.Connection, subs subscription.List) error { var errs error - // Make sure we have the concrete subscriptions, since we will change the state - subs := make(subscription.List, 0, len(keys)) - for _, key := range keys { - if s := e.Websocket.GetSubscription(key); s == nil { - errs = common.AppendError(errs, fmt.Errorf("%w; Channel: %s Pairs: %s", subscription.ErrNotFound, key.Channel, key.Pairs.Join())) - } else { - if s.State() != subscription.ResubscribingState { - if err := s.SetState(subscription.UnsubscribingState); err != nil { - errs = common.AppendError(errs, fmt.Errorf("%w; Channel: %s Pairs: %s", err, s.Channel, s.Pairs.Join())) - continue - } - } - subs = append(subs, s) + for _, s := range subs { + if s.State() == subscription.SubscribedState { + continue + } + _ = s.SetState(subscription.InactiveState) + if err := e.Websocket.RemoveSubscriptions(conn, s); err != nil { + errs = common.AppendError(errs, fmt.Errorf("error removing failed subscription: %w; Channel: %s Pairs: %s", err, s.Channel, s.Pairs.Join())) } } + return errs +} - subs = subs.GroupPairs() - - return common.AppendError(errs, - e.ParallelChanOp(ctx, subs, func(ctx context.Context, s subscription.List) error { return e.manageSubs(ctx, krakenWsUnsubscribe, s) }, 1), - ) +func (e *Exchange) unsubscribeForConnection(ctx context.Context, conn websocket.Connection, subs subscription.List) error { + return e.ParallelChanOp(ctx, subs.GroupPairs(), func(ctx context.Context, s subscription.List) error { + return e.manageSubs(ctx, krakenWsUnsubscribe, s, conn) + }, 1) } // manageSubs handles both websocket channel subscribe and unsubscribe -func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription.List) error { +func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription.List, conn websocket.Connection) error { if len(subs) != 1 { return subscription.ErrBatchingNotSupported } - s := subs[0] if err := enforceStandardChannelNames(s); err != nil { @@ -769,10 +727,8 @@ func (e *Exchange) manageSubs(ctx context.Context, op string, subs subscription. r.Subscription.Interval = int(time.Duration(s.Interval).Minutes()) } - conn := e.Websocket.Conn if s.Authenticated { r.Subscription.Token = e.websocketAuthToken() - conn = e.Websocket.AuthConn } resps, err := conn.SendMessageReturnResponses(ctx, request.Unset, r.RequestID, r, len(s.Pairs)) @@ -883,7 +839,7 @@ func (e *Exchange) getRespErr(resp []byte) error { // wsProcessSubStatus handles creating or removing Subscriptions as soon as we receive a message // It's job is to ensure that subscription state is kept correct sequentially between WS messages // If this responsibility was moved to Subscribe then we would have a race due to the channel connecting IncomingWithData -func (e *Exchange) wsProcessSubStatus(resp []byte) { +func (e *Exchange) wsProcessSubStatus(conn websocket.Connection, resp []byte) { pName, err := jsonparser.GetUnsafeString(resp, "pair") if err != nil { return @@ -921,7 +877,7 @@ func (e *Exchange) wsProcessSubStatus(resp []byte) { if status == krakenWsSubscribed { err = s.SetState(subscription.SubscribedState) } else if s.State() != subscription.ResubscribingState { // Do not remove a resubscribing sub which just unsubbed - err = e.Websocket.RemoveSubscriptions(e.Websocket.Conn, s) + err = e.Websocket.RemoveSubscriptions(conn, s) if e2 := s.SetState(subscription.UnsubscribedState); e2 != nil { err = common.AppendError(err, e2) } @@ -978,15 +934,27 @@ func fqChannelNameSub(s *subscription.Subscription) error { return nil } +func (e *Exchange) wsAuthConnection() (websocket.Connection, error) { + wsRunningAuthURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + if err != nil { + return nil, err + } + return e.Websocket.GetConnection(wsRunningAuthURL) +} + // wsAddOrder creates an order, returned order ID if success func (e *Exchange) wsAddOrder(ctx context.Context, req *WsAddOrderRequest) (string, error) { if req == nil { return "", common.ErrNilPointer } + conn, err := e.wsAuthConnection() + if err != nil { + return "", err + } req.RequestID = e.MessageSequence() req.Event = krakenWsAddOrder req.Token = e.websocketAuthToken() - jsonResp, err := e.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, req.RequestID, req) + jsonResp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.RequestID, req) if err != nil { return "", err } @@ -1011,15 +979,19 @@ func (e *Exchange) wsAddOrder(ctx context.Context, req *WsAddOrderRequest) (stri // wsCancelOrders cancels open orders concurrently // It does not use the multiple txId facility of the cancelOrder API because the errors are not specific func (e *Exchange) wsCancelOrders(ctx context.Context, orderIDs []string) error { + conn, err := e.wsAuthConnection() + if err != nil { + return err + } var errs common.ErrorCollector for _, id := range orderIDs { - errs.Go(func() error { return e.wsCancelOrder(ctx, id) }) + errs.Go(func() error { return e.wsCancelOrder(ctx, conn, id) }) } return errs.Collect() } // wsCancelOrder cancels an open order -func (e *Exchange) wsCancelOrder(ctx context.Context, orderID string) error { +func (e *Exchange) wsCancelOrder(ctx context.Context, conn websocket.Connection, orderID string) error { id := e.MessageSequence() req := WsCancelOrderRequest{ Event: krakenWsCancelOrder, @@ -1028,7 +1000,7 @@ func (e *Exchange) wsCancelOrder(ctx context.Context, orderID string) error { RequestID: id, } - resp, err := e.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, id, req) + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, id, req) if err != nil { return fmt.Errorf("%w %s: %w", errCancellingOrder, orderID, err) } @@ -1051,13 +1023,17 @@ func (e *Exchange) wsCancelOrder(ctx context.Context, orderID string) error { // wsCancelAllOrders cancels all opened orders // Returns number (count param) of affected orders or 0 if no open orders found func (e *Exchange) wsCancelAllOrders(ctx context.Context) (*WsCancelOrderResponse, error) { + conn, err := e.wsAuthConnection() + if err != nil { + return &WsCancelOrderResponse{}, err + } req := WsCancelOrderRequest{ Event: krakenWsCancelAll, Token: e.websocketAuthToken(), RequestID: e.MessageSequence(), } - jsonResp, err := e.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, req.RequestID, req) + jsonResp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.RequestID, req) if err != nil { return &WsCancelOrderResponse{}, err } diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index 59154ea0ada..19187265fde 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -195,24 +195,27 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } err = e.Websocket.Setup(&websocket.ManagerSetup{ - ExchangeConfig: exch, - DefaultURL: krakenWSURL, - RunningURL: wsRunningURL, - Connector: e.WsConnect, - Subscriber: e.Subscribe, - Unsubscriber: e.Unsubscribe, - GenerateSubscriptions: e.generateSubscriptions, - Features: &e.Features.Supports.WebsocketCapabilities, - OrderbookBufferConfig: buffer.Config{SortBuffer: true}, + ExchangeConfig: exch, + UseMultiConnectionManagement: true, + Features: &e.Features.Supports.WebsocketCapabilities, + MaxWebsocketSubscriptionsPerConnection: 200, // https://docs.kraken.com/api/docs/websocket-v2/level3/ (200 symbols per connection) + OrderbookBufferConfig: buffer.Config{SortBuffer: true}, }) if err != nil { return err } err = e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - RateLimit: request.NewWeightedRateLimitByDuration(50 * time.Millisecond), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + URL: wsRunningURL, + Connector: e.wsConnect, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generatePublicSubscriptions, + Handler: e.wsHandleData, + RateLimit: request.NewWeightedRateLimitByDuration(50 * time.Millisecond), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + MessageFilter: wsRunningURL, }) if err != nil { return err @@ -222,12 +225,20 @@ func (e *Exchange) Setup(exch *config.Exchange) error { if err != nil { return err } + return e.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ - RateLimit: request.NewWeightedRateLimitByDuration(50 * time.Millisecond), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Authenticated: true, - URL: wsRunningAuthURL, + URL: wsRunningAuthURL, + Connector: e.wsConnect, + Authenticate: e.wsAuthenticate, + Subscriber: e.subscribeForConnection, + Unsubscriber: e.unsubscribeForConnection, + GenerateSubscriptions: e.generatePrivateSubscriptions, + SubscriptionsNotRequired: true, + Handler: e.wsHandleData, + RateLimit: request.NewWeightedRateLimitByDuration(50 * time.Millisecond), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + MessageFilter: wsRunningAuthURL, }) } @@ -1390,17 +1401,6 @@ func (e *Exchange) GetOrderHistory(ctx context.Context, getOrdersRequest *order. return getOrdersRequest.Filter(e.Name, orders), nil } -// AuthenticateWebsocket sends an authentication message to the websocket -func (e *Exchange) AuthenticateWebsocket(ctx context.Context) error { - resp, err := e.GetWebsocketToken(ctx) - if err != nil { - return err - } - - e.setWebsocketAuthToken(resp) - return nil -} - // ValidateAPICredentials validates current credentials used for wrapper functionality func (e *Exchange) ValidateAPICredentials(ctx context.Context, assetType asset.Item) error { _, err := e.UpdateAccountBalances(ctx, assetType) diff --git a/exchanges/kraken/mock_ws_test.go b/exchanges/kraken/mock_ws_test.go index ef1675caa5b..dc6e9ed470a 100644 --- a/exchanges/kraken/mock_ws_test.go +++ b/exchanges/kraken/mock_ws_test.go @@ -8,6 +8,7 @@ import ( "github.com/buger/jsonparser" gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" ) @@ -22,6 +23,47 @@ func mockWsServer(tb testing.TB, msg []byte, w *gws.Conn) error { return mockWsCancelOrders(tb, msg, w) case krakenWsAddOrder: return mockWsAddOrder(tb, msg, w) + case krakenWsSubscribe, krakenWsUnsubscribe: + return mockWsSub(tb, msg, w, event) + } + return nil +} + +func mockWsSub(tb testing.TB, msg []byte, w *gws.Conn, event string) error { + tb.Helper() + var req WebsocketSubRequest + if err := json.Unmarshal(msg, &req); err != nil { + return err + } + status := event + "d" + channelName := req.Subscription.Name + switch channelName { + case "book": + channelName += fmt.Sprintf("-%d", req.Subscription.Depth) + case "ohlc": + channelName += fmt.Sprintf("-%d", req.Subscription.Interval) + } + + for _, p := range req.Pairs { + pair, err := currency.NewPairDelimiter(p, "/") + if err != nil { + return err + } + resp := WebsocketEventResponse{ + Event: krakenWsSubscriptionStatus, + Status: status, + RequestID: req.RequestID, + ChannelName: channelName, + Pair: pair, + } + resp.Subscription.Name = req.Subscription.Name + raw, err := json.Marshal(resp) + if err != nil { + return err + } + if err := w.WriteMessage(gws.TextMessage, raw); err != nil { + return err + } } return nil } diff --git a/exchanges/okx/okx_business_websocket.go b/exchanges/okx/okx_business_websocket.go index c50a27bdb44..1ee39150885 100644 --- a/exchanges/okx/okx_business_websocket.go +++ b/exchanges/okx/okx_business_websocket.go @@ -16,8 +16,6 @@ import ( const ( // okxBusinessWebsocketURL okxBusinessWebsocketURL = "wss://ws.okx.com:8443/ws/v5/business" - - businessConnection = "business" ) var ( diff --git a/exchanges/okx/okx_test.go b/exchanges/okx/okx_test.go index 4132b87cdc6..dd7e684eec5 100644 --- a/exchanges/okx/okx_test.go +++ b/exchanges/okx/okx_test.go @@ -3914,7 +3914,8 @@ func TestOrderPushData(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Test instance Setup must not error") - testexch.FixtureToDataHandler(t, "testdata/wsOrders.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, nil, b) }) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsOrders.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) e.Websocket.DataHandler.Close() require.Len(t, e.Websocket.DataHandler.C, 4, "Should see 4 orders") for resp := range e.Websocket.DataHandler.C { @@ -4018,6 +4019,7 @@ func TestWsHandleData(t *testing.T) { t.Parallel() e := new(Exchange) require.NoError(t, testexch.Setup(e), "Setup must not error") + conn := testexch.GetMockConn(t, e, "") for name, msg := range pushDataMap { switch name { @@ -4029,7 +4031,7 @@ func TestWsHandleData(t *testing.T) { e.API.AuthenticatedSupport = false e.API.AuthenticatedWebsocketSupport = false } - err := e.wsHandleData(t.Context(), nil, []byte(msg)) + err := e.wsHandleData(t.Context(), conn, []byte(msg)) if name == "Balance Save Error" { assert.ErrorIs(t, err, exchange.ErrAuthenticationSupportNotEnabled, "wsProcessBalanceAndPosition Accounts.Save should error without credentials") } else { @@ -4046,8 +4048,9 @@ func TestPushDataDynamic(t *testing.T) { "Snapshot OrderBook": `{"arg":{"channel":"books","instId":"BTC-USD-SWAP"},"action":"snapshot","data":[{"asks":[["0.07026","5","0","1"],["0.07027","765","0","3"],["0.07028","110","0","1"],["0.0703","1264","0","1"],["0.07034","280","0","1"],["0.07035","2255","0","1"],["0.07036","28","0","1"],["0.07037","63","0","1"],["0.07039","137","0","2"],["0.0704","48","0","1"],["0.07041","32","0","1"],["0.07043","3985","0","1"],["0.07057","257","0","1"],["0.07058","7870","0","1"],["0.07059","161","0","1"],["0.07061","4539","0","1"],["0.07068","1438","0","3"],["0.07088","3162","0","1"],["0.07104","99","0","1"],["0.07108","5018","0","1"],["0.07115","1540","0","1"],["0.07129","5080","0","1"],["0.07145","1512","0","1"],["0.0715","5016","0","1"],["0.07171","5026","0","1"],["0.07192","5062","0","1"],["0.07197","1517","0","1"],["0.0726","1511","0","1"],["0.07314","10376","0","1"],["0.07354","1","0","1"],["0.07466","10277","0","1"],["0.07626","269","0","1"],["0.07636","269","0","1"],["0.0809","1","0","1"],["0.08899","1","0","1"],["0.09789","1","0","1"],["0.10768","1","0","1"]],"bids":[["0.07014","56","0","2"],["0.07011","608","0","1"],["0.07009","110","0","1"],["0.07006","1264","0","1"],["0.07004","2347","0","3"],["0.07003","279","0","1"],["0.07001","52","0","1"],["0.06997","91","0","1"],["0.06996","4242","0","2"],["0.06995","486","0","1"],["0.06992","161","0","1"],["0.06991","63","0","1"],["0.06988","7518","0","1"],["0.06976","186","0","1"],["0.06975","71","0","1"],["0.06973","1086","0","1"],["0.06961","513","0","2"],["0.06959","4603","0","1"],["0.0695","186","0","1"],["0.06946","3043","0","1"],["0.06939","103","0","1"],["0.0693","5053","0","1"],["0.06909","5039","0","1"],["0.06888","5037","0","1"],["0.06886","1526","0","1"],["0.06867","5008","0","1"],["0.06846","5065","0","1"],["0.06826","1572","0","1"],["0.06801","1565","0","1"],["0.06748","67","0","1"],["0.0674","111","0","1"],["0.0672","10038","0","1"],["0.06652","1","0","1"],["0.06625","1526","0","1"],["0.06619","10924","0","1"],["0.05986","1","0","1"],["0.05387","1","0","1"],["0.04848","1","0","1"],["0.04363","1","0","1"]],"ts":"1659792392540","checksum":-1462286744}]}`, } var err error + conn := testexch.GetMockConn(t, e, "") for x := range dataMap { - err = e.wsHandleData(t.Context(), nil, []byte(dataMap[x])) + err = e.wsHandleData(t.Context(), conn, []byte(dataMap[x])) require.NoError(t, err) } } @@ -4134,7 +4137,7 @@ func TestWSProcessTrades(t *testing.T) { p := currency.NewPairWithDelimiter("BTC", "USDT", currency.DashDelimiter) for _, a := range assets { - err := e.Websocket.AddSubscriptions(e.Websocket.Conn, &subscription.Subscription{ + err := e.Websocket.AddSubscriptions(nil, &subscription.Subscription{ Asset: a, Pairs: currency.Pairs{p}, Channel: subscription.AllTradesChannel, @@ -4142,7 +4145,8 @@ func TestWSProcessTrades(t *testing.T) { }) require.NoError(t, err, "AddSubscriptions must not error") } - testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, nil, b) }) + conn := testexch.GetMockConn(t, e, "") + testexch.FixtureToDataHandler(t, "testdata/wsAllTrades.json", func(ctx context.Context, b []byte) error { return e.wsHandleData(ctx, conn, b) }) exp := []trade.Data{ { @@ -6202,7 +6206,9 @@ func TestBusinessWSCandleSubscriptions(t *testing.T) { require.NoError(t, e.Websocket.Connect(t.Context())) - conn, err := e.Websocket.GetConnection(businessConnection) + wsBusinessURL, err := e.API.Endpoints.GetURL(exchange.WebsocketSpotSupplementary) + require.NoError(t, err) + conn, err := e.Websocket.GetConnection(wsBusinessURL) require.NoError(t, err, "GetConnection must not error") err = e.BusinessSubscribe(t.Context(), conn, subscription.List{{Channel: channelCandle1D}}) diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index 570cb966ce9..b31ec69497f 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -250,7 +250,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Unsubscriber: e.Unsubscribe, Handler: e.wsHandleData, Authenticate: e.wsAuthenticateConnection, - MessageFilter: privateConnection, + MessageFilter: wsPrivate, }); err != nil { return err } @@ -271,7 +271,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { Unsubscriber: e.BusinessUnsubscribe, Handler: e.wsHandleData, Authenticate: e.wsAuthenticateConnection, - MessageFilter: businessConnection, + MessageFilter: wsBusiness, }) } diff --git a/exchanges/okx/ws_requests.go b/exchanges/okx/ws_requests.go index 00919265870..c5210a3daa6 100644 --- a/exchanges/okx/ws_requests.go +++ b/exchanges/okx/ws_requests.go @@ -8,6 +8,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/encoding/json" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/request" ) @@ -19,8 +20,6 @@ var ( errMassCancelFailed = errors.New("mass cancel failed") errCancelAllSpreadOrdersFailed = errors.New("cancel all spread orders failed") errMultipleItemsReturned = errors.New("multiple items returned") - - privateConnection = "private" ) // WSPlaceOrder submits an order @@ -255,7 +254,11 @@ func (e *Exchange) SendAuthenticatedWebsocketRequest(ctx context.Context, epl re return errInvalidWebsocketRequest } - conn, err := e.Websocket.GetConnection(privateConnection) + wsPrivateURL, err := e.API.Endpoints.GetURL(exchange.WebsocketPrivate) + if err != nil { + return err + } + conn, err := e.Websocket.GetConnection(wsPrivateURL) if err != nil { return err } diff --git a/internal/testing/exchange/exchange.go b/internal/testing/exchange/exchange.go index f130bbda79c..32cd36cb716 100644 --- a/internal/testing/exchange/exchange.go +++ b/internal/testing/exchange/exchange.go @@ -14,6 +14,7 @@ import ( "sync" "testing" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" @@ -23,6 +24,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchange/websocket" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/mock" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" testutils "github.com/thrasher-corp/gocryptotrader/internal/testing/utils" @@ -253,3 +255,64 @@ func UpdatePairsOnce(tb testing.TB, e exchange.IBotExchange) { cache.Load(&b.CurrencyPairs) updatePairsOnce[e.GetName()] = cache } + +// GetMockConn returns a mock websocket connection that can be used to test websocket handlers without needing to connect to a real websocket server +func GetMockConn(tb testing.TB, e exchange.IBotExchange, u string) websocket.Connection { + tb.Helper() + b := e.GetBase() + return &mockConn{ + match: b.Websocket.Match, + url: u, + } +} + +type mockConn struct { + match *websocket.Match + url string +} + +func (m *mockConn) Dial(context.Context, *gws.Dialer, http.Header, url.Values) error { + return nil +} +func (m *mockConn) ReadMessage() websocket.Response { return websocket.Response{} } +func (m *mockConn) SetupPingHandler(request.EndpointLimit, websocket.PingHandler) { +} + +func (m *mockConn) SendMessageReturnResponse(context.Context, request.EndpointLimit, any, any) ([]byte, error) { + return nil, nil +} + +func (m *mockConn) SendMessageReturnResponses(context.Context, request.EndpointLimit, any, any, int) ([][]byte, error) { + return nil, nil +} + +func (m *mockConn) SendMessageReturnResponsesWithInspector(context.Context, request.EndpointLimit, any, any, int, websocket.Inspector) ([][]byte, error) { + return nil, nil +} + +func (m *mockConn) SendRawMessage(context.Context, request.EndpointLimit, int, []byte) error { + return nil +} + +func (m *mockConn) SendJSONMessage(context.Context, request.EndpointLimit, any) error { + return nil +} +func (m *mockConn) SetURL(u string) { m.url = u } +func (m *mockConn) SetProxy(string) {} +func (m *mockConn) GetURL() string { return m.url } +func (m *mockConn) Shutdown() error { return nil } +func (m *mockConn) RequireMatchWithData(signature any, incoming []byte) error { + if m.match == nil { + return nil + } + return m.match.RequireMatchWithData(signature, incoming) +} + +func (m *mockConn) IncomingWithData(signature any, data []byte) bool { + return m.match != nil && m.match.IncomingWithData(signature, data) +} + +func (m *mockConn) MatchReturnResponses(context.Context, any, int) (<-chan websocket.MatchedResponse, error) { + return nil, nil +} +func (m *mockConn) Subscriptions() *subscription.Store { return nil }