Skip to content

Commit 8d368db

Browse files
committed
reduce number of keep-alive UDP packets #13
1 parent c391d38 commit 8d368db

File tree

4 files changed

+124
-89
lines changed

4 files changed

+124
-89
lines changed

tsshd/bus.go

Lines changed: 38 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,16 @@ var busClosing atomic.Bool
4343
var busClosingMu sync.Mutex
4444
var busClosingWG sync.WaitGroup
4545

46-
var globalActiveChecker *timeoutChecker
46+
var clientAliveTime aliveTime
47+
var pendingClearPktCache bool
4748

4849
func sendBusMessage(command string, msg any) error {
4950
busMutex.Lock()
5051
defer busMutex.Unlock()
5152
if busStream == nil {
5253
return fmt.Errorf("bus stream is nil")
5354
}
54-
if err := sendCommand(busStream, command); err != nil {
55-
return err
56-
}
57-
return sendMessage(busStream, msg)
55+
return sendCommandAndMessage(busStream, command, msg)
5856
}
5957

6058
func initBusStream(stream Stream) error {
@@ -101,27 +99,13 @@ func handleBusEvent(stream Stream) {
10199

102100
serving.Store(true)
103101

104-
globalActiveChecker = newTimeoutChecker(msg.HeartbeatTimeout)
105-
if enableDebugLogging {
106-
globalActiveChecker.onTimeout(func() {
107-
debug("transport offline, last activity at %v", time.UnixMilli(globalActiveChecker.getAliveTime()).Format("15:04:05.000"))
108-
})
109-
globalActiveChecker.onReconnected(func() {
110-
debug("transport resumed, last activity at %v", time.UnixMilli(globalActiveChecker.getAliveTime()).Format("15:04:05.000"))
111-
})
112-
}
113-
globalActiveChecker.onReconnected(func() {
114-
totalSize, totalCount := globalServerProxy.pktCache.clearCache()
115-
if enableDebugLogging {
116-
debug("drop packet cache count [%d] cache size [%d]", totalCount, totalSize)
117-
}
118-
})
102+
intervalTime := int64(msg.IntervalTime / time.Millisecond)
103+
heartbeatTimeout := int64(msg.HeartbeatTimeout / time.Millisecond)
119104

120-
globalServerProxy.clientChecker.timeoutMilli.Store(int64(msg.HeartbeatTimeout / time.Millisecond))
105+
globalServerProxy.clientChecker.timeoutMilli.Store(heartbeatTimeout)
121106

122-
activeAckChan := make(chan int64, 1)
123-
defer close(activeAckChan)
124-
go keepAlive(msg.AliveTimeout, msg.IntervalTime, activeAckChan)
107+
clientAliveTime.addMilli(time.Now().UnixMilli())
108+
go keepAlive(msg.AliveTimeout, msg.IntervalTime)
125109

126110
for {
127111
command, err := recvCommand(stream)
@@ -141,10 +125,8 @@ func handleBusEvent(stream Stream) {
141125
case "close":
142126
handleCloseEvent()
143127
return // return will close the bus stream
144-
case "alive1":
145-
err = handleAlive1Event(stream, activeAckChan)
146-
case "alive2":
147-
err = handleAlive2Event(stream)
128+
case "alive":
129+
err = handleAliveEvent(stream, heartbeatTimeout, intervalTime)
148130
case "setting":
149131
err = handleSettingEvent(stream)
150132
default:
@@ -189,23 +171,39 @@ func handleCloseEvent() {
189171
}()
190172
}
191173

192-
func handleAlive1Event(stream Stream, activeAckChan chan<- int64) error {
174+
func handleAliveEvent(stream Stream, heartbeatTimeout, intervalTime int64) error {
193175
var msg aliveMessage
194176
if err := recvMessage(stream, &msg); err != nil {
195177
return fmt.Errorf("recv alive message failed: %v", err)
196178
}
197179

198-
activeAckChan <- msg.Time
199-
return nil
200-
}
180+
now := time.Now().UnixMilli()
201181

202-
func handleAlive2Event(stream Stream) error {
203-
var msg aliveMessage
204-
if err := recvMessage(stream, &msg); err != nil {
205-
return fmt.Errorf("recv alive message failed: %v", err)
182+
// If the time since the last recorded activity exceeds heartbeatTimeout,
183+
// it indicates that the client was previously disconnected and has now reconnected.
184+
// Set the flag to clear the packet cache after the client stabilizes.
185+
if now-clientAliveTime.latest() > heartbeatTimeout {
186+
debug("client reconnected, last active at %v", time.UnixMilli(clientAliveTime.latest()).Format("15:04:05.000"))
187+
pendingClearPktCache = true
188+
} else if enableDebugLogging && pendingClearPktCache {
189+
debug("client active at %v", time.UnixMilli(now).Format("15:04:05.000"))
190+
}
191+
192+
clientAliveTime.addMilli(now)
193+
194+
if pendingClearPktCache {
195+
// If the client has remained active for a sufficient number of intervals,
196+
// consider the connection stable and clear the packet cache.
197+
if now-clientAliveTime.oldest() < (kAliveTimeCap+1)*intervalTime {
198+
totalSize, totalCount := globalServerProxy.pktCache.clearCache()
199+
if enableDebugLogging && (totalSize > 0 || totalCount > 0) {
200+
debug("drop packet cache count [%d] size [%d]", totalCount, totalSize)
201+
}
202+
pendingClearPktCache = false
203+
}
206204
}
207205

208-
return sendBusMessage("alive2", msg)
206+
return sendBusMessage("alive", msg)
209207
}
210208

211209
func handleSettingEvent(stream Stream) error {
@@ -230,29 +228,10 @@ func handleUnknownEvent(stream Stream, command string) error {
230228
return fmt.Errorf("unknown command: %s", command)
231229
}
232230

233-
func keepAlive(totalTimeout time.Duration, intervalTime time.Duration, activeAckChan <-chan int64) {
234-
ticker := time.NewTicker(intervalTime)
235-
defer ticker.Stop()
236-
go func() {
237-
for range ticker.C {
238-
aliveTime := time.Now().UnixMilli()
239-
if enableDebugLogging && globalActiveChecker.isTimeout() {
240-
debug("sending new keep alive [%d]", aliveTime)
241-
}
242-
if err := sendBusMessage("alive1", aliveMessage{aliveTime}); err != nil {
243-
warning("send keep alive [%d] failed: %v", aliveTime, err)
244-
} else if enableDebugLogging && globalActiveChecker.isTimeout() {
245-
debug("keep alive [%d] sent success", aliveTime)
246-
}
247-
248-
ackTime := <-activeAckChan
249-
globalActiveChecker.updateTime(ackTime)
250-
}
251-
}()
252-
253-
timeoutMilli := int64(totalTimeout / time.Millisecond)
231+
func keepAlive(aliveTimeout time.Duration, intervalTime time.Duration) {
232+
timeoutMilli := int64(aliveTimeout / time.Millisecond)
254233
for {
255-
if time.Now().UnixMilli()-globalActiveChecker.getAliveTime() > timeoutMilli {
234+
if time.Now().UnixMilli()-clientAliveTime.latest() > timeoutMilli {
256235
warning("tsshd keep alive timeout")
257236
exitChan <- 2
258237
return

tsshd/client.go

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ type SshUdpClient struct {
8181
activeAckChan chan int64
8282
reconnectMutex sync.Mutex
8383
reconnectError atomic.Pointer[error]
84+
pendingClearPkt atomic.Bool
8485
}
8586

8687
// UdpClientOptions contains all configuration parameters required to create and initialize a new SshUdpClient
@@ -154,10 +155,10 @@ func NewSshUdpClient(opts *UdpClientOptions) (*SshUdpClient, error) {
154155
})
155156
}
156157
udpClient.activeChecker.onReconnected(func() {
157-
totalSize, totalCount := udpClient.networkProxy.pktCache.clearCache()
158-
if enableDebugLogging {
159-
udpClient.debug("drop packet cache count [%d] cache size [%d]", totalCount, totalSize)
160-
}
158+
// Mark packet cache for deferred clearing.
159+
// The cache is NOT cleared immediately on reconnection because
160+
// the transport may appear reconnected while still being unstable.
161+
udpClient.pendingClearPkt.Store(true)
161162
})
162163
udpClient.activeChecker.onTimeout(udpClient.tryToReconnect)
163164

@@ -474,6 +475,7 @@ func (c *SshUdpClient) tryToReconnect() {
474475
}
475476

476477
func (c *SshUdpClient) keepAlive(intervalTime time.Duration) {
478+
var serverAliveTime aliveTime
477479
ticker := time.NewTicker(intervalTime)
478480
defer ticker.Stop()
479481

@@ -483,19 +485,37 @@ func (c *SshUdpClient) keepAlive(intervalTime time.Duration) {
483485
}
484486

485487
aliveTime := time.Now().UnixMilli()
486-
if c.activeChecker.isTimeout() && enableDebugLogging {
488+
if c.enableDebugging && c.activeChecker.isTimeout() {
487489
c.debug("sending new keep alive [%d]", aliveTime)
488490
}
489-
if err := c.sendBusMessage("alive2", aliveMessage{aliveTime}); err != nil {
491+
if err := c.sendBusMessage("alive", aliveMessage{aliveTime}); err != nil {
490492
if !c.IsClosed() {
491493
c.warning("send keep alive [%d] failed: %v", aliveTime, err)
492494
}
493-
} else if c.activeChecker.isTimeout() && enableDebugLogging {
495+
} else if c.enableDebugging && c.activeChecker.isTimeout() {
494496
c.debug("keep alive [%d] sent success", aliveTime)
495497
}
496498

497499
ackTime := <-c.activeAckChan
498500
c.activeChecker.updateTime(ackTime)
501+
502+
if c.pendingClearPkt.Load() {
503+
if c.enableDebugging {
504+
c.debug("server active at %v", time.UnixMilli(ackTime).Format("15:04:05.000"))
505+
}
506+
507+
serverAliveTime.addMilli(ackTime)
508+
509+
// If the server has remained active for a sufficient number of intervals,
510+
// consider the connection stable and clear the packet cache.
511+
if time.Since(time.UnixMilli(serverAliveTime.oldest())) < time.Duration(kAliveTimeCap+1)*intervalTime {
512+
totalSize, totalCount := c.networkProxy.pktCache.clearCache()
513+
if c.enableDebugging && (totalSize > 0 || totalCount > 0) {
514+
c.debug("drop packet cache count [%d] size [%d]", totalCount, totalSize)
515+
}
516+
c.pendingClearPkt.Store(false)
517+
}
518+
}
499519
}
500520
}
501521

@@ -508,10 +528,7 @@ func (c *SshUdpClient) sendBusCommand(command string) error {
508528
func (c *SshUdpClient) sendBusMessage(command string, msg any) error {
509529
c.busMutex.Lock()
510530
defer c.busMutex.Unlock()
511-
if err := sendCommand(c.busStream, command); err != nil {
512-
return err
513-
}
514-
return sendMessage(c.busStream, msg)
531+
return sendCommandAndMessage(c.busStream, command, msg)
515532
}
516533

517534
func (c *SshUdpClient) handleBusEvent() {
@@ -535,10 +552,8 @@ func (c *SshUdpClient) handleBusEvent() {
535552
c.handleErrorEvent()
536553
case "channel":
537554
c.handleChannelEvent()
538-
case "alive1":
539-
c.handleAlive1Event()
540-
case "alive2":
541-
c.handleAlive2Event()
555+
case "alive":
556+
c.handleAliveEvent()
542557
case "discard":
543558
c.handleDiscardEvent()
544559
default:
@@ -629,21 +644,7 @@ func (c *SshUdpClient) handleChannelEvent() {
629644
}
630645
}
631646

632-
func (c *SshUdpClient) handleAlive1Event() {
633-
var aliveMsg aliveMessage
634-
if err := recvMessage(c.busStream, &aliveMsg); err != nil {
635-
c.warning("recv alive message failed: %v", err)
636-
return
637-
}
638-
639-
if err := c.sendBusMessage("alive1", aliveMsg); err != nil {
640-
if !c.IsClosed() {
641-
c.warning("send alive message failed: %v", err)
642-
}
643-
}
644-
}
645-
646-
func (c *SshUdpClient) handleAlive2Event() {
647+
func (c *SshUdpClient) handleAliveEvent() {
647648
var aliveMsg aliveMessage
648649
if err := recvMessage(c.busStream, &aliveMsg); err != nil {
649650
c.warning("recv alive message failed: %v", err)

tsshd/comm.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,30 @@ func (tc *timeoutChecker) Close() {
348348
}
349349
close(tc.closeChan)
350350
}
351+
352+
// kAliveTimeCap is the number of most recent activity timestamps that an
// aliveTime retains.
const kAliveTimeCap = 10

// aliveTime is a fixed-size ring buffer holding the kAliveTimeCap most
// recently recorded timestamps, expressed in Unix milliseconds.
//
// The zero value is ready to use. Slots that have not been written yet read
// as 0, so latest returns 0 before the first addMilli call and oldest returns
// 0 until kAliveTimeCap timestamps have been recorded. All methods lock the
// embedded mutex for the duration of the call.
type aliveTime struct {
	mutex sync.Mutex
	last  int                  // index of the most recently written slot
	buf   [kAliveTimeCap]int64 // recorded timestamps in Unix milliseconds
}

// addMilli records milli as the newest timestamp, overwriting the slot that
// currently holds the oldest one.
func (t *aliveTime) addMilli(milli int64) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	t.last = (t.last + 1) % kAliveTimeCap
	t.buf[t.last] = milli
}

// latest returns the most recently recorded timestamp, or 0 if nothing has
// been recorded yet.
func (t *aliveTime) latest() int64 {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	return t.buf[t.last]
}

// oldest returns the timestamp stored in the slot that the next addMilli call
// will overwrite. Once the buffer has wrapped this is the oldest retained
// timestamp; before that it is 0.
func (t *aliveTime) oldest() int64 {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	return t.buf[(t.last+1)%kAliveTimeCap]
}

tsshd/proto.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,34 @@ func recvMessage(stream Stream, msg any) error {
292292
return nil
293293
}
294294

295+
func sendCommandAndMessage(stream Stream, command string, msg any) error {
296+
if len(command) == 0 {
297+
return fmt.Errorf("send command is empty")
298+
}
299+
if len(command) > 255 {
300+
return fmt.Errorf("send command too long: %s", command)
301+
}
302+
303+
msgBuf, err := json.Marshal(msg)
304+
if err != nil {
305+
return fmt.Errorf("send message marshal failed: %w", err)
306+
}
307+
308+
totalLen := 1 + len(command) + 4 + len(msgBuf)
309+
buffer := make([]byte, totalLen)
310+
311+
buffer[0] = uint8(len(command))
312+
copy(buffer[1:], []byte(command))
313+
314+
binary.BigEndian.PutUint32(buffer[1+len(command):], uint32(len(msgBuf)))
315+
copy(buffer[1+len(command)+4:], msgBuf)
316+
317+
if err := writeAll(stream, buffer); err != nil {
318+
return fmt.Errorf("send command and message failed: %w", err)
319+
}
320+
return nil
321+
}
322+
295323
func sendError(stream Stream, err error) {
296324
if e := sendMessage(stream, errorMessage{Msg: err.Error()}); e != nil {
297325
warning("send error [%v] failed: %v", err, e)

0 commit comments

Comments
 (0)