diff --git a/acceptor.go b/acceptor.go index f58ef01f7..0228778c1 100644 --- a/acceptor.go +++ b/acceptor.go @@ -18,6 +18,7 @@ package quickfix import ( "bufio" "bytes" + "context" "crypto/tls" "io" "net" @@ -361,6 +362,7 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { a.sessionAddr.Store(sessID, netConn.RemoteAddr()) msgIn := make(chan fixIn) msgOut := make(chan []byte) + ctx := context.Background() if err := session.connect(msgIn, msgOut); err != nil { a.globalLog.OnEventf("Unable to accept session %v connection: %v", sessID, err.Error()) @@ -369,10 +371,10 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { go func() { msgIn <- fixIn{msgBytes, parser.lastRead} - readLoop(parser, msgIn, a.globalLog) + readLoop(ctx, parser, msgIn, a.globalLog) }() - writeLoop(netConn, msgOut, a.globalLog) + writeLoop(ctx, netConn, msgOut, a.globalLog) } func (a *Acceptor) dynamicSessionsLoop() { diff --git a/connection.go b/connection.go index 99a4c465e..95e77b239 100644 --- a/connection.go +++ b/connection.go @@ -15,10 +15,19 @@ package quickfix -import "io" +import ( + "context" + "io" +) -func writeLoop(connection io.Writer, messageOut chan []byte, log Log) { +func writeLoop(ctx context.Context, connection io.Writer, messageOut chan []byte, log Log) { for { + select { + case <-ctx.Done(): + return + default: + } + msg, ok := <-messageOut if !ok { return @@ -30,10 +39,16 @@ func writeLoop(connection io.Writer, messageOut chan []byte, log Log) { } } -func readLoop(parser *parser, msgIn chan fixIn, log Log) { +func readLoop(ctx context.Context, parser *parser, msgIn chan fixIn, log Log) { defer close(msgIn) for { + select { + case <-ctx.Done(): + return + default: + } + msg, err := parser.ReadMessage() if err != nil { log.OnEvent(err.Error()) diff --git a/connection_internal_test.go b/connection_internal_test.go index 081b3c110..daa045dd0 100644 --- a/connection_internal_test.go +++ b/connection_internal_test.go @@ -17,11 +17,13 @@ package quickfix import ( "bytes" + "context" "strings" "testing" ) func TestWriteLoop(t *testing.T) { + ctx := context.Background() writer := bytes.NewBufferString("") msgOut := make(chan []byte) @@ -31,7 +33,7 @@ func TestWriteLoop(t *testing.T) { msgOut <- []byte("test msg 3") close(msgOut) }() - writeLoop(writer, msgOut, nullLog{}) + writeLoop(ctx, writer, msgOut, nullLog{}) expected := "test msg 1 test msg 2 test msg 3" @@ -41,11 +43,12 @@ func TestWriteLoop(t *testing.T) { } func TestReadLoop(t *testing.T) { + ctx := context.Background() msgIn := make(chan fixIn) stream := "hello8=FIX.4.09=5blah10=103garbage8=FIX.4.09=4foo10=103" parser := newParser(strings.NewReader(stream)) - go readLoop(parser, msgIn, nullLog{}) + go readLoop(ctx, parser, msgIn, nullLog{}) var tests = []struct { expectedMsg string diff --git a/initiator.go b/initiator.go index 18451477e..991001320 100644 --- a/initiator.go +++ b/initiator.go @@ -163,14 +163,17 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di return } - ctx, cancel := context.WithCancel(context.Background()) + ctx := context.Background() + dialCtx, dialCancel := context.WithCancel(ctx) + readWriteCtx, readWriteCancel := context.WithCancel(ctx) // We start a goroutine in order to be able to cancel the dialer mid-connection // on receiving a stop signal to stop the initiator. go func() { select { case <-i.stopChan: - cancel() + dialCancel() + readWriteCancel() case <-ctx.Done(): return } @@ -183,7 +186,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)] session.log.OnEventf("Connecting to: %v", address) - netConn, err := dialer.DialContext(ctx, "tcp", address) + netConn, err := dialer.DialContext(dialCtx, "tcp", address) if err != nil { session.log.OnEventf("Failed to connect: %v", err) goto reconnect @@ -207,24 +210,26 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di msgIn = make(chan fixIn) msgOut = make(chan []byte) - if err := session.connect(msgIn, msgOut); err != nil { - session.log.OnEventf("Failed to initiate: %v", err) - goto reconnect - } + - go readLoop(newParser(bufio.NewReader(netConn)), msgIn, session.log) + go readLoop(readWriteCtx,newParser(bufio.NewReader(netConn)), msgIn, session.log) disconnected = make(chan interface{}) go func() { - writeLoop(netConn, msgOut, session.log) + writeLoop(readWriteCtx,netConn, msgOut, session.log) if err := netConn.Close(); err != nil { session.log.OnEvent(err.Error()) } close(disconnected) }() + if err := session.connect(msgIn, msgOut); err != nil { + session.log.OnEventf("Failed to initiate: %v", err) + goto reconnect + } + // This ensures we properly cleanup the goroutine and context used for // dial cancelation after successful connection. - cancel() + dialCancel() select { case <-disconnected: @@ -233,7 +238,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di } reconnect: - cancel() + dialCancel() connectionAttempt++ session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval) diff --git a/session.go b/session.go index a6d296999..27a06b625 100644 --- a/session.go +++ b/session.go @@ -849,15 +849,15 @@ func (s *session) onAdmin(msg interface{}) { return } - if msg.err != nil { - close(msg.err) - } - s.messageIn = msg.messageIn s.messageOut = msg.messageOut s.sentReset = false - s.Connect(s) + err := s.Connect(s) + if msg.err != nil { + msg.err <- err + close(msg.err) + } case stopReq: s.Stop(s) diff --git a/session_state.go b/session_state.go index 6fe4dded7..509803a0e 100644 --- a/session_state.go +++ b/session_state.go @@ -36,36 +36,37 @@ func (sm *stateMachine) Start(s *session) { sm.CheckSessionTime(s, time.Now()) } -func (sm *stateMachine) Connect(session *session) { +func (sm *stateMachine) Connect(session *session) error{ // No special logon logic needed for FIX Acceptors. if !session.InitiateLogon { sm.setState(session, logonState{}) - return + return nil } if session.RefreshOnLogon { if err := session.store.Refresh(); err != nil { session.logError(err) - return + return err } } if session.ResetOnLogon { if err := session.store.Reset(); err != nil { session.logError(err) - return + return err } } session.log.OnEvent("Sending logon request") if err := session.sendLogon(); err != nil { session.logError(err) - return + return err } sm.setState(session, logonState{}) // Fire logon timeout event after the pre-configured delay period. time.AfterFunc(session.LogonTimeout, func() { session.sessionEvent <- internal.LogonTimeout }) + return nil } func (sm *stateMachine) Stop(session *session) {