Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
360d0e5
Refactor websocket handling to support context propagation
Sep 22, 2025
4be8b4b
linter/AI: nits and fixes
Sep 22, 2025
cce08e7
linter: fixes again
Sep 22, 2025
86d218c
fix test
Sep 22, 2025
0cb3876
Merge branch 'master' into websock_read_ctx
Oct 14, 2025
98c0817
instead of a read function just expose the read only channel field li…
Oct 14, 2025
a87e39f
Merge branch 'master' into websock_read_ctx
Oct 28, 2025
6b0a3b5
Merge branch 'master' into websock_read_ctx
Nov 5, 2025
f380808
glorious: nits
Nov 9, 2025
5eacbae
Merge branch 'master' into websock_read_ctx
Nov 12, 2025
88e48a1
glorious: nits
Nov 12, 2025
c8fb278
rm jank for finding issue
Nov 12, 2025
f6dd7c3
linter: fix
Nov 12, 2025
fc7df14
glorious: more nits
Nov 12, 2025
770d7a6
Add todo for context removal
Nov 12, 2025
7d6e595
Merge branch 'master' into websock_read_ctx
Nov 27, 2025
cabcd83
Merge branch 'master' into websock_read_ctx
Nov 30, 2025
f21e02d
Implement context freezing and thawing functions; update Payload to u…
Dec 1, 2025
7625f60
linter: fix
Dec 1, 2025
7185eb9
linter: fix again
Dec 1, 2025
2eb5f6a
Merge branch 'master' into websock_read_ctx
Dec 1, 2025
e42ae00
refactor: simplify FrozenContext structure and update related functions
Dec 2, 2025
bab00e0
Update common/common.go
shazbert Dec 2, 2025
1680819
Update common/common.go
shazbert Dec 2, 2025
239d160
glorious: removal of check
Dec 3, 2025
07bba2a
thrasher: rm unused error return
Dec 3, 2025
1e63153
Update exchange/websocket/manager.go
shazbert Dec 3, 2025
9f5fbda
rm: check for nil datahandler
Dec 3, 2025
ca24dc1
gk: rm rpcContextToLongLivedSession
Dec 4, 2025
db647bf
gk: rn package
Dec 4, 2025
abd83f7
thrasher-: nits
Jan 21, 2026
bd8cd65
Merge branch 'master' into websock_read_ctx
Jan 21, 2026
8a5069f
what a silly billy
Jan 21, 2026
478b2eb
Merge branch 'master' into websock_read_ctx
Jan 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,58 @@ func SetIfZero[T comparable](p *T, def T) bool {
*p = def
return true
}

var (
contextKeys []any
contextKeysMu sync.RWMutex
)

// RegisterContextKey registers a key to be captured by FreezeContext
func RegisterContextKey(key any) {
contextKeysMu.Lock()
defer contextKeysMu.Unlock()
if !slices.Contains(contextKeys, key) {
contextKeys = append(contextKeys, key)
}
}

// FrozenContext holds captured context values
type FrozenContext map[any]any

// FreezeContext captures values from the context for registered keys
func FreezeContext(ctx context.Context) FrozenContext {
contextKeysMu.RLock()
defer contextKeysMu.RUnlock()

values := make(FrozenContext, len(contextKeys))
for _, key := range contextKeys {
if val := ctx.Value(key); val != nil {
values[key] = val
}
}
return values
}

// ThawContext creates a new context from the frozen context using context.Background() as parent
func ThawContext(fc FrozenContext) context.Context {
return MergeContext(context.Background(), fc)
}

// MergeContext adds the frozen values to an existing context
func MergeContext(ctx context.Context, fc FrozenContext) context.Context {
return &mergedContext{Context: ctx, frozen: fc}
}

// mergedContext is a context that has merged values from a frozen context and a parent context.
// frozen values are stored in FrozenContext instead of nested context.WithValue because of the performance of calling WithValue N+ times on messages being frozen
type mergedContext struct {
context.Context //nolint:containedctx // mergedContext implements context.Context
frozen FrozenContext
}

func (m *mergedContext) Value(key any) any {
if val, ok := m.frozen[key]; ok {
return val
}
return m.Context.Value(key)
}
34 changes: 34 additions & 0 deletions common/common_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -691,3 +692,36 @@ func TestSetIfZero(t *testing.T) {
assert.True(t, changed, "SetIfZero should change a zero value")
assert.Equal(t, "world", s, "SetIfZero should change a zero value")
}

func TestContextFunctions(t *testing.T) {
t.Parallel()

type key string
const k1 key = "key1"
const k2 key = "key2"
const k3 key = "key3"

RegisterContextKey(k1)
RegisterContextKey(k2)

ctx := context.WithValue(context.Background(), k1, "value1")
ctx = context.WithValue(ctx, k2, "value2")
ctx = context.WithValue(ctx, k3, "value3") // Not registered

frozen := FreezeContext(ctx)

assert.Equal(t, "value1", frozen[k1], "should have captured k1")
assert.Equal(t, "value2", frozen[k2], "should have captured k2")
assert.Zero(t, frozen[k3], "k3 should not be captured")

thawed := ThawContext(frozen)
assert.Equal(t, "value1", thawed.Value(k1), "should have k1 after thaw")
assert.Equal(t, "value2", thawed.Value(k2), "should have k2 after thaw")
assert.Nil(t, thawed.Value(k3), "Thawed context should not have k3")

ctx2 := context.WithValue(context.Background(), k3, "value3_new")
merged := MergeContext(ctx2, frozen)
assert.Equal(t, "value1", merged.Value(k1), "should have k1 from frozen")
assert.Equal(t, "value2", merged.Value(k2), "should have k2 from frozen")
assert.Equal(t, "value3_new", merged.Value(k3), "should have k3 from parent")
}
30 changes: 13 additions & 17 deletions docs/ADD_NEW_EXCHANGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -693,19 +693,17 @@ func (e *Exchange) WsConnect() error {
// KeepAuthKeyAlive will continuously send messages to
// keep the WS auth key active
func (e *Exchange) KeepAuthKeyAlive(ctx context.Context) {
e.Websocket.Wg.Add(1)
defer e.Websocket.Wg.Done()
ticks := time.NewTicker(time.Minute * 30)
for {
select {
case <-e.Websocket.ShutdownC:
ticks.Stop()
return
case <-ticks.C:
err := e.MaintainWsAuthStreamKey(ctx)
if err != nil {
e.Websocket.DataHandler <- err
log.Warnf(log.ExchangeSys, "%s - Unable to renew auth websocket token, may experience shutdown", e.Name)
case <-time.After(time.Minute * 30):
if err := e.MaintainWsAuthStreamKey(ctx); err != nil {
if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil {
log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err)
}
log.Warnf(log.ExchangeSys, "%s %s: Unable to renew auth websocket token, may experience shutdown", e.Name, e.Websocket.Conn.GetURL())
}
}
}
Expand Down Expand Up @@ -817,9 +815,7 @@ Run gocryptotrader with the following settings enabled in config
```go
// wsReadData gets and passes on websocket messages for processing
func (e *Exchange) wsReadData() {
e.Websocket.Wg.Add(1)
defer e.Websocket.Wg.Done()

for {
select {
case <-e.Websocket.ShutdownC:
Expand All @@ -829,10 +825,10 @@ func (e *Exchange) wsReadData() {
if resp.Raw == nil {
return
}

err := e.wsHandleData(resp.Raw)
if err != nil {
e.Websocket.DataHandler <- err
if err := e.wsHandleData(ctx, resp.Raw); err != nil {
if errSend := e.Websocket.DataHandler.Send(ctx, err); errSend != nil {
log.Errorf(log.WebsocketMgr, "%s %s: %s %s", e.Name, e.Websocket.Conn.GetURL(), errSend, err)
}
}
}
}
Expand Down Expand Up @@ -875,15 +871,15 @@ If a suitable struct does not exist in wshandler, wrapper types are the next pre
if err := json.Unmarshal(respRaw, &resultData);err != nil {
return err
}
e.Websocket.DataHandler <- &ticker.Price{
return e.Websocket.DataHandler.Send(ctx, &ticker.Price{
ExchangeName: e.Name,
Bid: resultData.Ticker.Bid,
Ask: resultData.Ticker.Ask,
Last: resultData.Ticker.Last,
LastUpdated: resultData.Ticker.Time,
Pair: p,
AssetType: a,
}
})
}
```

Expand All @@ -896,7 +892,7 @@ If neither of those provide a suitable struct to store the data in, the data can
if err != nil {
return err
}
e.Websocket.DataHandler <- resultData.FillsData
return e.Websocket.DataHandler.Send(ctx, resultData.FillsData)
```

- Data Handling can be tested offline similar to the following example:
Expand Down
15 changes: 5 additions & 10 deletions engine/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2947,7 +2947,7 @@ func (s *RPCServer) WebsocketGetInfo(_ context.Context, r *gctrpc.WebsocketGetIn
}

// WebsocketSetEnabled enables or disables the websocket client
func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) {
func (s *RPCServer) WebsocketSetEnabled(ctx context.Context, r *gctrpc.WebsocketSetEnabledRequest) (*gctrpc.GenericResponse, error) {
exch, err := s.GetExchangeByName(r.Exchange)
if err != nil {
return nil, err
Expand All @@ -2964,11 +2964,9 @@ func (s *RPCServer) WebsocketSetEnabled(_ context.Context, r *gctrpc.WebsocketSe
}

if r.Enable {
err = w.Enable()
if err != nil {
if err := w.Enable(context.WithoutCancel(ctx)); err != nil {
return nil, err
}

exchCfg.Features.Enabled.Websocket = true
return &gctrpc.GenericResponse{Status: MsgStatusSuccess, Data: "websocket enabled"}, nil
}
Expand Down Expand Up @@ -3013,7 +3011,7 @@ func (s *RPCServer) WebsocketGetSubscriptions(_ context.Context, r *gctrpc.Webso
}

// WebsocketSetProxy sets client websocket connection proxy
func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) {
func (s *RPCServer) WebsocketSetProxy(ctx context.Context, r *gctrpc.WebsocketSetProxyRequest) (*gctrpc.GenericResponse, error) {
exch, err := s.GetExchangeByName(r.Exchange)
if err != nil {
return nil, err
Expand All @@ -3024,15 +3022,12 @@ func (s *RPCServer) WebsocketSetProxy(_ context.Context, r *gctrpc.WebsocketSetP
return nil, fmt.Errorf("websocket not supported for exchange %s", r.Exchange)
}

err = w.SetProxyAddress(r.Proxy)
if err != nil {
if err := w.SetProxyAddress(context.WithoutCancel(ctx), r.Proxy); err != nil {
return nil, err
}
return &gctrpc.GenericResponse{
Status: MsgStatusSuccess,
Data: fmt.Sprintf("new proxy has been set [%s] for %s websocket connection",
r.Exchange,
r.Proxy),
Data: fmt.Sprintf("new proxy has been set [%s] for %s websocket connection", r.Exchange, r.Proxy),
}, nil
}

Expand Down
10 changes: 5 additions & 5 deletions engine/websocketroutine_manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package engine

import (
"context"
"fmt"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -139,7 +140,7 @@ func (m *WebsocketRoutineManager) websocketRoutine() {
log.Errorf(log.WebsocketMgr, "%v", err)
}

if err := ws.Connect(); err != nil {
if err := ws.Connect(context.TODO()); err != nil {
log.Errorf(log.WebsocketMgr, "%v", err)
}
})
Expand Down Expand Up @@ -167,14 +168,13 @@ func (m *WebsocketRoutineManager) websocketDataReceiver(ws *websocket.Manager) e
select {
case <-m.shutdown:
return
case data := <-ws.ToRoutine:
if data == nil {
case payload := <-ws.DataHandler.C:
if payload.Data == nil {
log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName())
}
m.mu.RLock()
for x := range m.dataHandlers {
err := m.dataHandlers[x](ws.GetName(), data)
if err != nil {
if err := m.dataHandlers[x](ws.GetName(), payload.Data); err != nil {
log.Errorln(log.WebsocketMgr, err)
}
}
Expand Down
10 changes: 6 additions & 4 deletions engine/websocketroutine_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,18 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) {
}

mock := websocket.NewManager()
mock.ToRoutine = make(chan any)
m.state = readyState
err = m.websocketDataReceiver(mock)
if err != nil {
t.Fatal(err)
}

mock.ToRoutine <- nil
mock.ToRoutine <- 1336
mock.ToRoutine <- "intercepted"
err = mock.DataHandler.Send(t.Context(), nil)
require.NoError(t, err)
err = mock.DataHandler.Send(t.Context(), 1336)
require.NoError(t, err)
err = mock.DataHandler.Send(t.Context(), "intercepted")
require.NoError(t, err)

if r := <-dataChan; r != "intercepted" {
t.Fatal("unexpected value received")
Expand Down
48 changes: 48 additions & 0 deletions exchange/stream/relay.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package stream

import (
"context"
"errors"
"fmt"

"github.com/thrasher-corp/gocryptotrader/common"
)

var errChannelBufferFull = errors.New("channel buffer is full")

// Relay defines a channel relay for messages
type Relay struct {
C <-chan Payload
comm chan Payload
}

// Payload represents a relayed message with a context
type Payload struct {
Ctx common.FrozenContext
Data any
}

// NewRelay creates a new Relay instance with a specified buffer size
func NewRelay(buffer uint) *Relay {
if buffer == 0 {
panic("buffer size must be greater than 0")
}
comm := make(chan Payload, buffer)
return &Relay{comm: comm, C: comm}
}

// Send sends a message to the channel receiver
// This is non-blocking and returns an error if the channel buffer is full
func (r *Relay) Send(ctx context.Context, data any) error {
select {
case r.comm <- Payload{Ctx: common.FreezeContext(ctx), Data: data}:
return nil
default:
return fmt.Errorf("%w: failed to relay <%T>", errChannelBufferFull, data)
}
}

// Close closes the relay channel
func (r *Relay) Close() {
close(r.comm)
}
43 changes: 43 additions & 0 deletions exchange/stream/relay_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package stream

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewRelay(t *testing.T) {
t.Parallel()
assert.Panics(t, func() { NewRelay(0) }, "buffer size should be greater than 0")
r := NewRelay(5)
require.NotNil(t, r)
assert.Equal(t, 5, cap(r.comm))
}

func TestSend(t *testing.T) {
t.Parallel()
r := NewRelay(1)
require.NotNil(t, r)
assert.NoError(t, r.Send(t.Context(), "test"))
assert.ErrorIs(t, r.Send(t.Context(), "overflow"), errChannelBufferFull)
}

func TestRead(t *testing.T) {
t.Parallel()
r := NewRelay(1)
require.NotNil(t, r)
require.Empty(t, r.C)
assert.NoError(t, r.Send(t.Context(), "test"))
require.Len(t, r.C, 1)
assert.Equal(t, "test", (<-r.C).Data)
}

func TestClose(t *testing.T) {
t.Parallel()
r := NewRelay(1)
require.NotNil(t, r)
r.Close()
_, ok := <-r.C
assert.False(t, ok)
}
Loading
Loading