Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ var (
ErrSettingField = errors.New("error setting field")
ErrParsingWSField = errors.New("error parsing websocket field")
ErrMalformedData = errors.New("malformed data")
ErrFatal = errors.New("fatal error")
)

var (
Expand Down
21 changes: 15 additions & 6 deletions exchange/websocket/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"time"

gws "github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/encoding/json"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
Expand Down Expand Up @@ -56,13 +55,17 @@ type Connection interface {
IncomingWithData(signature any, data []byte) bool
// MatchReturnResponses sets up a channel to listen for an expected number of responses.
MatchReturnResponses(ctx context.Context, signature any, expected int) (<-chan MatchedResponse, error)
// Subscriptions returns the subscription store for the connection
Subscriptions() *subscription.Store
}

// ConnectionSetup defines variables for an individual stream connection
type ConnectionSetup struct {
ResponseCheckTimeout time.Duration
ResponseMaxLimit time.Duration
RateLimit *request.RateLimiterWithWeight
ResponseCheckTimeout time.Duration
ResponseMaxLimit time.Duration
RateLimit *request.RateLimiterWithWeight
// ConnectionRateLimiter returns a new rate limiter for each connection instance
ConnectionRateLimiter func() *request.RateLimiterWithWeight
Authenticated bool // unused for multi-connection websocket
SubscriptionsNotRequired bool
ConnectionLevelReporter Reporter
Expand Down Expand Up @@ -109,6 +112,7 @@ type Response struct {

// connection contains all the data needed to send a message to a websocket connection
type connection struct {
subscriptions *subscription.Store
Verbose bool
connected int32
writeControl sync.Mutex // Gorilla websocket does not allow more than one goroutine to utilise write methods
Expand Down Expand Up @@ -328,8 +332,8 @@ func (c *connection) parseBinaryResponse(resp []byte) ([]byte, error) {

// Shutdown shuts down and closes specific connection
func (c *connection) Shutdown() error {
if err := common.NilGuard(c, c.Connection); err != nil {
return err
if c == nil || c.Connection == nil {
return nil // Allow Shutdown to be called during early startup/teardown when the socket hasn't been created yet.
}
c.setConnectedStatus(false)
c.writeControl.Lock()
Expand Down Expand Up @@ -471,3 +475,8 @@ func (c *connection) RequireMatchWithData(signature any, incoming []byte) error
func (c *connection) IncomingWithData(signature any, data []byte) bool {
return c.Match.IncomingWithData(signature, data)
}

// Subscriptions returns the subscription store for the connection
func (c *connection) Subscriptions() *subscription.Store {
return c.subscriptions
}
11 changes: 11 additions & 0 deletions exchange/websocket/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions"
)

func TestMatchReturnResponses(t *testing.T) {
Expand Down Expand Up @@ -59,3 +61,12 @@ func TestIncomingWithData(t *testing.T) {
require.Len(t, ch, 1, "must have one item in channel")
assert.Equal(t, []byte("test"), <-ch)
}

func TestConnectionSubscriptions(t *testing.T) {
t.Parallel()
ws := &connection{}
require.Nil(t, ws.Subscriptions())
ws.subscriptions = subscription.NewStore()
require.NotNil(t, ws.Subscriptions())
testsubs.EqualLists(t, ws.subscriptions.List(), ws.Subscriptions().List())
}
Loading
Loading