Skip to content

Commit 6b4fb8c

Browse files
committed
Acceptor can listen on different ports per session
1 parent 4717f65 commit 6b4fb8c

File tree

1 file changed

+52
-36
lines changed

1 file changed

+52
-36
lines changed

acceptor.go

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"strconv"
1111
"sync"
1212

13-
"github.com/armon/go-proxyproto"
13+
proxyproto "github.com/armon/go-proxyproto"
1414
"github.com/quickfixgo/quickfix/config"
1515
)
1616

@@ -23,13 +23,14 @@ type Acceptor struct {
2323
globalLog Log
2424
sessions map[SessionID]*session
2525
sessionGroup sync.WaitGroup
26-
listener net.Listener
2726
listenerShutdown sync.WaitGroup
2827
dynamicSessions bool
2928
dynamicQualifier bool
3029
dynamicQualifierCount int
3130
dynamicSessionChan chan *session
3231
sessionAddr map[SessionID]net.Addr
32+
sessionHostPort map[SessionID]int
33+
listeners map[string]net.Listener
3334
connectionValidator ConnectionValidator
3435
sessionFactory
3536
}
@@ -42,46 +43,49 @@ type ConnectionValidator interface {
4243
}
4344

4445
//Start accepting connections.
45-
func (a *Acceptor) Start() error {
46+
func (a *Acceptor) Start() (err error) {
4647
socketAcceptHost := ""
4748
if a.settings.GlobalSettings().HasSetting(config.SocketAcceptHost) {
48-
var err error
4949
if socketAcceptHost, err = a.settings.GlobalSettings().Setting(config.SocketAcceptHost); err != nil {
50-
return err
50+
return
5151
}
5252
}
5353

54-
socketAcceptPort, err := a.settings.GlobalSettings().IntSetting(config.SocketAcceptPort)
55-
if err != nil {
56-
return err
54+
a.sessionHostPort = make(map[SessionID]int)
55+
a.listeners = make(map[string]net.Listener)
56+
for sessionID, sessionSettings := range a.settings.SessionSettings() {
57+
if sessionSettings.HasSetting(config.SocketAcceptPort) {
58+
if a.sessionHostPort[sessionID], err = sessionSettings.IntSetting(config.SocketAcceptPort); err != nil {
59+
return
60+
}
61+
} else if a.sessionHostPort[sessionID], err = a.settings.GlobalSettings().IntSetting(config.SocketAcceptPort); err != nil {
62+
return
63+
}
64+
address := net.JoinHostPort(socketAcceptHost, strconv.Itoa(a.sessionHostPort[sessionID]))
65+
a.listeners[address] = nil
5766
}
5867

5968
var tlsConfig *tls.Config
6069
if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil {
61-
return err
70+
return
6271
}
6372

6473
var useTCPProxy bool
6574
if a.settings.GlobalSettings().HasSetting(config.UseTCPProxy) {
6675
if useTCPProxy, err = a.settings.GlobalSettings().BoolSetting(config.UseTCPProxy); err != nil {
67-
return err
76+
return
6877
}
6978
}
7079

71-
address := net.JoinHostPort(socketAcceptHost, strconv.Itoa(socketAcceptPort))
72-
if tlsConfig != nil {
73-
if a.listener, err = tls.Listen("tcp", address, tlsConfig); err != nil {
74-
return err
75-
}
76-
} else if useTCPProxy {
77-
listener, err := net.Listen("tcp", address)
78-
if err != nil {
79-
return err
80-
}
81-
a.listener = &proxyproto.Listener{Listener: listener}
82-
} else {
83-
if a.listener, err = net.Listen("tcp", address); err != nil {
84-
return err
80+
for address := range a.listeners {
81+
if tlsConfig != nil {
82+
if a.listeners[address], err = tls.Listen("tcp", address, tlsConfig); err != nil {
83+
return
84+
}
85+
} else if a.listeners[address], err = net.Listen("tcp", address); err != nil {
86+
return
87+
} else if useTCPProxy {
88+
a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]}
8589
}
8690
}
8791

@@ -101,9 +105,11 @@ func (a *Acceptor) Start() error {
101105
a.sessionGroup.Done()
102106
}()
103107
}
104-
a.listenerShutdown.Add(1)
105-
go a.listenForConnections()
106-
return nil
108+
a.listenerShutdown.Add(len(a.listeners))
109+
for _, listener := range a.listeners {
110+
go a.listenForConnections(listener)
111+
}
112+
return
107113
}
108114

109115
//Stop logs out existing sessions, close their connections, and stop accepting new connections.
@@ -112,7 +118,9 @@ func (a *Acceptor) Stop() {
112118
_ = recover() // suppress sending on closed channel error
113119
}()
114120

115-
a.listener.Close()
121+
for _, listener := range a.listeners {
122+
listener.Close()
123+
}
116124
a.listenerShutdown.Wait()
117125
if a.dynamicSessions {
118126
close(a.dynamicSessionChan)
@@ -132,12 +140,14 @@ func (a *Acceptor) RemoteAddr(sessionID SessionID) (net.Addr, bool) {
132140
//NewAcceptor creates and initializes a new Acceptor.
133141
func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Settings, logFactory LogFactory) (a *Acceptor, err error) {
134142
a = &Acceptor{
135-
app: app,
136-
storeFactory: storeFactory,
137-
settings: settings,
138-
logFactory: logFactory,
139-
sessions: make(map[SessionID]*session),
140-
sessionAddr: make(map[SessionID]net.Addr),
143+
app: app,
144+
storeFactory: storeFactory,
145+
settings: settings,
146+
logFactory: logFactory,
147+
sessions: make(map[SessionID]*session),
148+
sessionAddr: make(map[SessionID]net.Addr),
149+
sessionHostPort: make(map[SessionID]int),
150+
listeners: make(map[string]net.Listener),
141151
}
142152
if a.settings.GlobalSettings().HasSetting(config.DynamicSessions) {
143153
if a.dynamicSessions, err = settings.globalSettings.BoolSetting(config.DynamicSessions); err != nil {
@@ -171,11 +181,11 @@ func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Se
171181
return
172182
}
173183

174-
func (a *Acceptor) listenForConnections() {
184+
func (a *Acceptor) listenForConnections(listener net.Listener) {
175185
defer a.listenerShutdown.Done()
176186

177187
for {
178-
netConn, err := a.listener.Accept()
188+
netConn, err := listener.Accept()
179189
if err != nil {
180190
return
181191
}
@@ -276,6 +286,12 @@ func (a *Acceptor) handleConnection(netConn net.Conn) {
276286
TargetCompID: string(senderCompID), TargetSubID: string(senderSubID), TargetLocationID: string(senderLocationID),
277287
}
278288

289+
localConnectionPort := netConn.LocalAddr().(*net.TCPAddr).Port
290+
if expectedPort, ok := a.sessionHostPort[sessID]; !ok || expectedPort != localConnectionPort {
291+
a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes)
292+
return
293+
}
294+
279295
// We have a session ID and a network connection. This seems to be a good place for any custom authentication logic.
280296
if a.connectionValidator != nil {
281297
if err := a.connectionValidator.Validate(netConn, sessID); err != nil {

0 commit comments

Comments
 (0)