Skip to content

Commit 56cc4de

Browse files
authored
Merge pull request #8 from caio-northfleet/race-issues
mitigated specific client race issues
2 parents 10fc48e + 7e396c4 commit 56cc4de

File tree

3 files changed

+493
-49
lines changed

3 files changed

+493
-49
lines changed

client.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ import (
2020
"os/signal"
2121
"strconv"
2222
"sync"
23+
"sync/atomic"
2324
"syscall"
25+
"unsafe"
2426

2527
"github.com/scmhub/ibapi/protobuf"
2628
"google.golang.org/protobuf/proto"
@@ -340,9 +342,9 @@ func (c *EClient) reset() {
340342
}
341343

342344
func (c *EClient) setConnState(state ConnState) {
343-
cs := c.connState
344-
c.connState = state
345-
log.Debug().Stringer("from", cs).Stringer("to", c.connState).Msg("connection state changed")
345+
cs := ConnState(atomic.LoadInt32((*int32)(unsafe.Pointer(&c.connState))))
346+
atomic.StoreInt32((*int32)(unsafe.Pointer(&c.connState)), int32(state))
347+
log.Debug().Stringer("from", cs).Stringer("to", state).Msg("connection state changed")
346348
}
347349

348350
// request is a goroutine that will get the req from reqChan and send it to TWS.
@@ -611,7 +613,7 @@ func (c *EClient) Ctx() context.Context {
611613

612614
// IsConnected checks connection to TWS or GateWay.
613615
func (c *EClient) IsConnected() bool {
614-
return c.conn.IsConnected() && c.connState == CONNECTED
616+
return c.conn.IsConnected() && ConnState(atomic.LoadInt32((*int32)(unsafe.Pointer(&c.connState)))) == CONNECTED
615617
}
616618

617619
// OptionalCapabilities returns the Optional Capabilities.
@@ -1117,7 +1119,7 @@ func (c *EClient) CancelCalculateOptionPrice(reqID int64) {
11171119
//
11181120
// exerciseQuantity is the quantity you want to exercise.
11191121
// account is the destination account.
1120-
// override specifies whether your setting will override the system's natural action.
1122+
// override specifies whether your setting will override the system's natural action.
11211123
// For example, if your action is "exercise" and the option is not in-the-money, by natural action the option would not exercise.
11221124
// If you have override set to "yes" the natural action would be overridden and the out-of-the money option would be exercised.
11231125
// Values: 0 = no, 1 = yes.

connection.go

Lines changed: 133 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package ibapi
33
import (
44
"fmt"
55
"net"
6+
"sync"
7+
"sync/atomic"
68
"time"
79
)
810

@@ -11,91 +13,151 @@ const (
1113
reconnectDelay = 500 * time.Millisecond
1214
)
1315

14-
// Connection is a TCPConn wrapper.
16+
// Connection is a TCPConn wrapper with lock-free statistics and minimal contention.
1517
type Connection struct {
16-
*net.TCPConn
17-
wrapper EWrapper
18-
host string
19-
port int
20-
isConnected bool
21-
numBytesSent int
22-
numMsgSent int
23-
numBytesRecv int
24-
numMsgRecv int
18+
// Connection state - mu guards host and port only; tcpConn and isConnected are accessed atomically
19+
mu sync.RWMutex
20+
tcpConn atomic.Pointer[net.TCPConn] // Lock-free pointer for maximum performance
21+
wrapper EWrapper
22+
host string
23+
port int
24+
isConnected int32 // atomic: 0=disconnected, 1=connected
25+
26+
// Statistics - lock-free atomic counters for maximum performance
27+
numBytesSent int64 // atomic
28+
numMsgSent int64 // atomic
29+
numBytesRecv int64 // atomic
30+
numMsgRecv int64 // atomic
31+
32+
// Reconnection control - prevents multiple concurrent reconnections
33+
reconnecting int32 // atomic: 0=not reconnecting, 1=reconnecting
2534
}
2635

2736
func (c *Connection) Write(bs []byte) (int, error) {
28-
// first attempt
29-
n, err := c.TCPConn.Write(bs)
30-
if err == nil {
31-
c.numBytesSent += n
32-
c.numMsgSent++
33-
log.Trace().Int("nBytes", n).Msg("conn write")
34-
return n, nil
37+
// Fast path: try write with current connection
38+
conn := c.getConn()
39+
if conn != nil {
40+
n, err := conn.Write(bs)
41+
if err == nil {
42+
// Lock-free atomic statistics update
43+
atomic.AddInt64(&c.numBytesSent, int64(n))
44+
atomic.AddInt64(&c.numMsgSent, 1)
45+
log.Trace().Int("nBytes", n).Msg("conn write")
46+
return n, nil
47+
}
48+
49+
// Write failed, try to reconnect
50+
log.Warn().Err(err).Msg("Write error detected, attempting to reconnect...")
3551
}
36-
// write failed, try to reconnect
37-
log.Warn().Err(err).Msg("Write error detected, attempting to reconnect...")
52+
53+
// Slow path: reconnect and retry
3854
if err := c.reconnect(); err != nil {
3955
return 0, fmt.Errorf("write failed and reconnection failed: %w", err)
4056
}
4157

42-
// second attempt
43-
n, err = c.TCPConn.Write(bs)
58+
// Retry write after reconnection
59+
conn = c.getConn()
60+
if conn == nil {
61+
return 0, fmt.Errorf("connection still not available after reconnect")
62+
}
63+
64+
n, err := conn.Write(bs)
4465
if err != nil {
4566
return 0, fmt.Errorf("write retry after reconnect failed: %w", err)
4667
}
4768

48-
c.numBytesSent += n
49-
c.numMsgSent++
69+
// Lock-free atomic statistics update
70+
atomic.AddInt64(&c.numBytesSent, int64(n))
71+
atomic.AddInt64(&c.numMsgSent, 1)
5072
log.Trace().Int("nBytes", n).Msg("conn write (after reconnect)")
5173
return n, nil
5274
}
5375

5476
func (c *Connection) Read(bs []byte) (int, error) {
55-
n, err := c.TCPConn.Read(bs)
77+
conn := c.getConn()
78+
if conn == nil {
79+
return 0, fmt.Errorf("connection not available")
80+
}
5681

57-
c.numBytesRecv += n
58-
c.numMsgRecv++
82+
n, err := conn.Read(bs)
83+
84+
// Lock-free atomic statistics update
85+
atomic.AddInt64(&c.numBytesRecv, int64(n))
86+
atomic.AddInt64(&c.numMsgRecv, 1)
5987

6088
log.Trace().Int("nBytes", n).Msg("conn read")
6189

6290
return n, err
6391
}
6492

93+
// getConn returns the current TCP connection in a lock-free way
94+
func (c *Connection) getConn() *net.TCPConn {
95+
return c.tcpConn.Load()
96+
}
97+
98+
// setConn sets the TCP connection in a lock-free way
99+
func (c *Connection) setConn(conn *net.TCPConn) {
100+
c.tcpConn.Store(conn)
101+
}
102+
65103
func (c *Connection) reset() {
66-
c.numBytesSent = 0
67-
c.numBytesRecv = 0
68-
c.numMsgSent = 0
69-
c.numMsgRecv = 0
104+
// Lock-free atomic reset of statistics
105+
atomic.StoreInt64(&c.numBytesSent, 0)
106+
atomic.StoreInt64(&c.numBytesRecv, 0)
107+
atomic.StoreInt64(&c.numMsgSent, 0)
108+
atomic.StoreInt64(&c.numMsgRecv, 0)
70109
}
71110

72111
func (c *Connection) connect(host string, port int) error {
112+
// Protect host/port assignment with mutex to prevent races
113+
c.mu.Lock()
73114
c.host = host
74115
c.port = port
116+
c.mu.Unlock()
117+
75118
c.reset()
76119

77-
address := fmt.Sprintf("%v:%v", c.host, c.port)
120+
// Use the parameters directly instead of reading from struct to avoid races
121+
address := fmt.Sprintf("%v:%v", host, port)
78122
addr, err := net.ResolveTCPAddr("tcp4", address)
79123
if err != nil {
80124
log.Error().Err(err).Str("host", address).Msg("failed to resove tcp address")
81125
c.wrapper.Error(NO_VALID_ID, currentTimeMillis(), FAIL_CREATE_SOCK.Code, FAIL_CREATE_SOCK.Msg, "")
82126
return err
83127
}
84128

85-
c.TCPConn, err = net.DialTCP("tcp4", nil, addr)
129+
newConn, err := net.DialTCP("tcp4", nil, addr)
86130
if err != nil {
87131
log.Error().Err(err).Any("address", addr).Msg("failed to dial tcp")
88132
c.wrapper.Error(NO_VALID_ID, currentTimeMillis(), FAIL_CREATE_SOCK.Code, FAIL_CREATE_SOCK.Msg, "")
89133
return err
90134
}
91135

92-
log.Debug().Any("address", c.TCPConn.RemoteAddr()).Msg("tcp socket connected")
93-
c.isConnected = true
136+
// Publish the new connection, then flag it as connected (two separate atomic stores, not a single atomic step)
137+
c.setConn(newConn)
138+
atomic.StoreInt32(&c.isConnected, 1)
94139

140+
log.Debug().Any("address", newConn.RemoteAddr()).Msg("tcp socket connected")
95141
return nil
96142
}
97143

98144
func (c *Connection) reconnect() error {
145+
// Use atomic CAS to prevent multiple concurrent reconnections
146+
if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
147+
// Another goroutine is already reconnecting, wait for it
148+
for atomic.LoadInt32(&c.reconnecting) == 1 {
149+
time.Sleep(10 * time.Millisecond)
150+
}
151+
// Check if the other goroutine succeeded
152+
if atomic.LoadInt32(&c.isConnected) == 1 {
153+
return nil
154+
}
155+
return fmt.Errorf("concurrent reconnection failed")
156+
}
157+
158+
// Ensure we clear the reconnecting flag when done
159+
defer atomic.StoreInt32(&c.reconnecting, 0)
160+
99161
var err error
100162
backoff := reconnectDelay // Start with base delay
101163

@@ -106,35 +168,62 @@ func (c *Connection) reconnect() error {
106168
Int("maxAttempts", maxReconnectAttempts).
107169
Msg("Attempting to reconnect")
108170

109-
err = c.connect(c.host, c.port)
171+
// Snapshot host/port under the read lock to avoid racing with the write in connect
172+
c.mu.RLock()
173+
host, port := c.host, c.port
174+
c.mu.RUnlock()
175+
176+
err = c.connect(host, port)
110177
if err == nil {
111178
log.Info().Msg("Reconnection successful")
112-
c.isConnected = true
179+
atomic.StoreInt32(&c.isConnected, 1)
113180
return nil
114181
}
115182

116-
// if this isnt our last try, wait and then loop again
183+
// if this isn't our last try, wait and then loop again
117184
if attempt < maxReconnectAttempts {
118185
time.Sleep(backoff)
119186
backoff *= 2
120187
}
121188
}
122189

123190
// if we get here, all attempts failed
124-
c.isConnected = false
191+
atomic.StoreInt32(&c.isConnected, 0)
125192
return fmt.Errorf("failed to reconnect after %d attempts: %w", maxReconnectAttempts, err)
126-
127193
}
128194

129195
func (c *Connection) disconnect() error {
196+
// Load statistics atomically for logging
197+
msgSent := atomic.LoadInt64(&c.numMsgSent)
198+
bytesSent := atomic.LoadInt64(&c.numBytesSent)
199+
msgRecv := atomic.LoadInt64(&c.numMsgRecv)
200+
bytesRecv := atomic.LoadInt64(&c.numBytesRecv)
201+
130202
log.Trace().
131-
Int("nMsgSent", c.numMsgSent).Int("nBytesSent", c.numBytesSent).
132-
Int("nMsgRecv", c.numMsgRecv).Int("nBytesRecv", c.numBytesRecv).
203+
Int64("nMsgSent", msgSent).Int64("nBytesSent", bytesSent).
204+
Int64("nMsgRecv", msgRecv).Int64("nBytesRecv", bytesRecv).
133205
Msg("conn disconnect")
134-
c.isConnected = false
135-
return c.Close()
206+
207+
// Atomically mark as disconnected
208+
atomic.StoreInt32(&c.isConnected, 0)
209+
210+
// Close the connection
211+
conn := c.getConn()
212+
if conn != nil {
213+
c.setConn(nil)
214+
return conn.Close()
215+
}
216+
return nil
136217
}
137218

138219
func (c *Connection) IsConnected() bool {
139-
return c.isConnected
220+
return atomic.LoadInt32(&c.isConnected) == 1
221+
}
222+
223+
// GetStatistics returns the current connection statistics. Each counter is loaded atomically, but the four values are not a single consistent snapshot.
224+
func (c *Connection) GetStatistics() (bytesSent, msgSent, bytesRecv, msgRecv int64) {
225+
return atomic.LoadInt64(&c.numBytesSent),
226+
atomic.LoadInt64(&c.numMsgSent),
227+
atomic.LoadInt64(&c.numBytesRecv),
228+
atomic.LoadInt64(&c.numMsgRecv)
140229
}

0 commit comments

Comments
 (0)