Skip to content

Commit 9015c7a

Browse files
committed
connect
1 parent 21c5fb3 commit 9015c7a

File tree

5 files changed

+81
-27
lines changed

5 files changed

+81
-27
lines changed

client.go

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"bytes"
1515
"context"
1616
"encoding/binary"
17+
"errors"
1718
"fmt"
1819
"os"
1920
"os/signal"
@@ -48,24 +49,25 @@ func (cs ConnState) String() string {
4849

4950
// EClient is the main struct to use from API user's point of view.
5051
type EClient struct {
51-
host string
52-
port int
53-
clientID int64
54-
connectOptions string
55-
conn *Connection
56-
serverVersion Version
57-
connTime string
58-
connState ConnState
59-
writer *bufio.Writer
60-
scanner *bufio.Scanner
61-
wrapper EWrapper
62-
decoder *EDecoder
63-
reqChan chan []byte
64-
Ctx context.Context
65-
Cancel context.CancelFunc
66-
extraAuth bool
67-
wg sync.WaitGroup
68-
err error
52+
host string
53+
port int
54+
clientID int64
55+
connectOptions string
56+
optionalCapabilities string
57+
conn *Connection
58+
serverVersion Version
59+
connTime string
60+
connState ConnState
61+
writer *bufio.Writer
62+
scanner *bufio.Scanner
63+
wrapper EWrapper
64+
decoder *EDecoder
65+
reqChan chan []byte
66+
Ctx context.Context
67+
Cancel context.CancelFunc
68+
extraAuth bool
69+
wg sync.WaitGroup
70+
err error
6971
}
7072

7173
// NewEClient returns a new Eclient.
@@ -84,6 +86,8 @@ func (c *EClient) reset() {
8486
c.host = ""
8587
c.port = -1
8688
c.clientID = -1
89+
c.connectOptions = ""
90+
c.optionalCapabilities = ""
8791
c.extraAuth = false
8892
c.conn = &Connection{}
8993
c.serverVersion = -1
@@ -143,6 +147,18 @@ func (c *EClient) request() {
143147
}
144148
}
145149
}
150+
func (c *EClient) validateInvalidSymbols(host string) error {
151+
if host != "" && !isASCIIPrintable(host) {
152+
return errors.New(host)
153+
}
154+
if c.connectOptions != "" && !isASCIIPrintable(c.connectOptions) {
155+
return errors.New(c.connectOptions)
156+
}
157+
if c.optionalCapabilities != "" && !isASCIIPrintable(c.optionalCapabilities) {
158+
return errors.New(c.optionalCapabilities)
159+
}
160+
return nil
161+
}
146162

147163
// startAPI initiates the message exchange between the client application and the TWS/IB Gateway.
148164
func (c *EClient) startAPI() error {
@@ -157,7 +173,7 @@ func (c *EClient) startAPI() error {
157173
const VERSION = 2
158174

159175
if c.serverVersion >= MIN_SERVER_VER_OPTIONAL_CAPABILITIES {
160-
msg = makeFields(START_API, VERSION, c.clientID, "")
176+
msg = makeFields(START_API, VERSION, c.clientID, c.optionalCapabilities)
161177
} else {
162178
msg = makeFields(START_API, VERSION, c.clientID)
163179
}
@@ -174,11 +190,19 @@ func (c *EClient) startAPI() error {
174190

175191
// Connect must be called before any other.
176192
// There is no feedback for a successful connection, but a subsequent attempt to connect will return the message "Already connected.".
193+
// You should wait for the connection to be established and NextValidID to be returned before calling any other function. If you don't wait, you will get a broken pipe error.
177194
func (c *EClient) Connect(host string, port int, clientID int64) error {
195+
196+
if err := c.validateInvalidSymbols(host); err != nil {
197+
c.wrapper.Error(NO_VALID_ID, currentTimeMillis(), INVALID_SYMBOL.Code, INVALID_SYMBOL.Msg+err.Error(), "")
198+
return err
199+
}
200+
178201
c.host = host
179202
c.port = port
180203
c.clientID = clientID
181204

205+
// Connecting to IB server
182206
log.Info().Str("host", host).Int("port", port).Int64("clientID", clientID).Msg("Connecting to IB server")
183207
if err := c.conn.connect(c.host, c.port); err != nil {
184208
log.Error().Err(CONNECT_FAIL).Msg("Connection fail")
@@ -188,7 +212,7 @@ func (c *EClient) Connect(host string, port int, clientID int64) error {
188212
}
189213

190214
// HandShake with the TWS or GateWay to ensure the version,
191-
log.Debug().Msg("HandShake with TWS or GateWay")
215+
log.Debug().Msg("Handshake with TWS or GateWay")
192216

193217
head := []byte("API\x00")
194218

@@ -197,7 +221,7 @@ func (c *EClient) Connect(host string, port int, clientID int64) error {
197221
connectOptions = " " + c.connectOptions
198222
}
199223
sizeofCV := make([]byte, 4)
200-
clientVersion := []byte(fmt.Sprintf("v%d..%d%s", MIN_CLIENT_VER, MAX_CLIENT_VER, connectOptions))
224+
clientVersion := fmt.Appendf(nil, "v%d..%d%s", MIN_CLIENT_VER, MAX_CLIENT_VER, connectOptions)
201225

202226
binary.BigEndian.PutUint32(sizeofCV, uint32(len(clientVersion)))
203227

@@ -206,7 +230,7 @@ func (c *EClient) Connect(host string, port int, clientID int64) error {
206230
msg.Write(sizeofCV)
207231
msg.Write(clientVersion)
208232

209-
log.Debug().Bytes("header", msg.Bytes()).Msg("send handShake header")
233+
log.Debug().Bytes("header", msg.Bytes()).Msg("Sending handshake header")
210234

211235
if _, err := c.writer.Write(msg.Bytes()); err != nil {
212236
return err
@@ -216,7 +240,7 @@ func (c *EClient) Connect(host string, port int, clientID int64) error {
216240
return err
217241
}
218242

219-
log.Debug().Msg("recv handShake Info")
243+
log.Debug().Msg("Receiving handshake Info")
220244

221245
// scan once to get server info
222246
if !c.scanner.Scan() {
@@ -292,9 +316,19 @@ func (c *EClient) IsConnected() bool {
292316
return c.conn.IsConnected() && c.connState == CONNECTED
293317
}
294318

319+
// OptionalCapabilities returns the Optional Capabilities.
320+
func (c *EClient) OptionalCapabilities() string {
321+
return c.optionalCapabilities
322+
}
323+
324+
// SetOptionalCapabilities setup the Optional Capabilities.
325+
func (c *EClient) SetOptionalCapabilities(optCapts string) {
326+
c.optionalCapabilities = optCapts
327+
}
328+
295329
// SetConnectionOptions setup the Connection Options.
296-
func (c *EClient) SetConnectionOptions(opts string) {
297-
c.connectOptions = opts
330+
func (c *EClient) SetConnectionOptions(connectOptions string) {
331+
c.connectOptions = connectOptions
298332
}
299333

300334
// ReqCurrentTime asks the current system time on the server side.

client_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ func setupIBClient(t *testing.T) *EClient {
4343
t.Fatalf("Couldn't connect EClient: %v", err)
4444
}
4545

46+
// Add a short delay to allow the connection to stabilize
47+
time.Sleep(100 * time.Millisecond)
48+
4649
return globalIB
4750
}
4851

examples/basic/basic.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99

1010
const (
1111
host = "localhost"
12-
port = 7496
12+
port = 7497
1313
)
1414

1515
var orderID int64
@@ -34,6 +34,10 @@ func main() {
3434
return
3535
}
3636

37+
// Add a short delay to allow the connection to stabilize
38+
time.Sleep(100 * time.Millisecond)
39+
log.Info().Msg("Waited for connection to stabilize")
40+
3741
// ib.SetConnectionOptions("+PACEAPI")
3842

3943
// Logger test
@@ -49,7 +53,7 @@ func main() {
4953
// log.Print("TWS Connection time: ", ib.TWSConnectionTime())
5054

5155
// time.Sleep(1 * time.Second)
52-
ib.ReqCurrentTime()
56+
// ib.ReqCurrentTime()
5357

5458
// ########## account ##########
5559
ib.ReqManagedAccts()

examples/blocking/blocking.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ func main() {
6262
return
6363
}
6464
defer ib.Disconnect()
65+
// Add a short delay to allow the connection to stabilize
66+
time.Sleep(100 * time.Millisecond)
67+
log.Info().Msg("Waited for connection to stabilize")
6568

6669
// Request servert current time
6770
t := ib.ReqCurrentTime()

utils.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,13 @@ func GetTimeStrFromMillis(timestamp int64) string {
298298
}
299299
return ""
300300
}
301+
302+
// isASCIIPrintable checks if all characters in the given string are ASCII printable characters.
303+
func isASCIIPrintable(s string) bool {
304+
for _, r := range s {
305+
if r == '\t' || r == '\n' || r == '\r' || r < 32 || r > 126 {
306+
return false
307+
}
308+
}
309+
return true
310+
}

0 commit comments

Comments
 (0)