@@ -61,27 +61,27 @@ type udpClient interface {
6161
6262// SshUdpClient implements a UDP SSH client
6363type SshUdpClient struct {
64- client udpClient
65- proxy * clientProxy
66- connectTimeout time.Duration
67- waitGroup sync.WaitGroup
68- closed atomic.Bool
69- busMutex sync.Mutex
70- busStream net.Conn
71- sessionMutex sync.Mutex
72- sessionID atomic.Uint64
73- sessionMap map [uint64 ]* SshUdpSession
74- channelMutex sync.Mutex
75- channelMap map [string ]chan ssh.NewChannel
76- aliveCallback func (int64 )
77- aliveNewVer bool
78- serverName string
79- quitCallback func (string )
64+ client udpClient
65+ proxy * clientProxy
66+ connectTimeout time.Duration
67+ waitGroup sync.WaitGroup
68+ closed atomic.Bool
69+ busMutex sync.Mutex
70+ busStream net.Conn
71+ sessionMutex sync.Mutex
72+ sessionID atomic.Uint64
73+ sessionMap map [uint64 ]* SshUdpSession
74+ channelMutex sync.Mutex
75+ channelMap map [string ]chan ssh.NewChannel
76+ aliveCallback func (int64 )
77+ aliveNewVer bool
78+ quitCallback func ( string )
79+ discardCallback func ([] byte , [] byte )
8080}
8181
8282// NewSshUdpClient creates a SshUdpClient
8383func NewSshUdpClient (addr string , info * ServerInfo , connectTimeout , aliveTimeout , intervalTime time.Duration ,
84- quitCallback func (string )) (* SshUdpClient , error ) {
84+ quitCallback func (string ), discardCallback func ([] byte , [] byte ) ) (* SshUdpClient , error ) {
8585 var proxy * clientProxy
8686 if info .ProxyKey != "" {
8787 var err error
@@ -99,12 +99,13 @@ func NewSshUdpClient(addr string, info *ServerInfo, connectTimeout, aliveTimeout
9999 return nil , err
100100 }
101101 udpClient := & SshUdpClient {
102- client : client ,
103- proxy : proxy ,
104- sessionMap : make (map [uint64 ]* SshUdpSession ),
105- channelMap : make (map [string ]chan ssh.NewChannel ),
106- connectTimeout : connectTimeout ,
107- quitCallback : quitCallback ,
102+ client : client ,
103+ proxy : proxy ,
104+ sessionMap : make (map [uint64 ]* SshUdpSession ),
105+ channelMap : make (map [string ]chan ssh.NewChannel ),
106+ connectTimeout : connectTimeout ,
107+ quitCallback : quitCallback ,
108+ discardCallback : discardCallback ,
108109 }
109110
110111 busStream , err := udpClient .newStream ("bus" )
@@ -140,27 +141,27 @@ func (c *SshUdpClient) Close() error {
140141
141142 _ , _ = doWithTimeout (func () (int , error ) {
142143 if err := c .sendBusCommand ("close" ); err != nil {
143- debug ("[client] [%s] send cmd [close] failed: %v" , c . serverName , err )
144+ debug ("[client] send cmd [close] failed: %v" , err )
144145 } else {
145- debug ("[client] [%s] send cmd [close] completed" , c . serverName )
146+ debug ("[client] send cmd [close] completed" )
146147 }
147148 // UDP connections do not support half-close (write-only close) for now,
148149 // so we add extra wait time to allow all incoming data to be received.
149150 time .Sleep (200 * time .Millisecond ) // give udp some time
150151 if err := c .busStream .Close (); err != nil {
151- debug ("[client] [%s] close bus stream failed: %v" , c . serverName , err )
152+ debug ("[client] close bus stream failed: %v" , err )
152153 } else {
153- debug ("[client] [%s] close bus stream completed" , c . serverName )
154+ debug ("[client] close bus stream completed" )
154155 }
155156 return 0 , nil
156157 }, 300 * time .Millisecond )
157158
158159 _ , err := doWithTimeout (func () (int , error ) {
159160 err := c .client .closeClient ()
160161 if err != nil {
161- debug ("[client] [%s] close client failed: %v" , c . serverName , err )
162+ debug ("[client] close client failed: %v" , err )
162163 } else {
163- debug ("[client] [%s] close client completed" , c . serverName )
164+ debug ("[client] close client completed" )
164165 }
165166 return 0 , err
166167 }, 200 * time .Millisecond )
@@ -170,16 +171,19 @@ func (c *SshUdpClient) Close() error {
170171
171172// Reconnect creates a new UDP path to the server
172173func (c * SshUdpClient ) Reconnect (timeout time.Duration ) error {
173- if c .proxy != nil {
174- if err := c .proxy .renewUdpPath (timeout ); err != nil {
175- return err
176- }
177- if err := c .sendBusCommand ("alive" ); err != nil { // ping the server
178- return fmt .Errorf ("ping server failed: %w" , err )
179- }
180- return nil
174+ if c .proxy == nil {
175+ return fmt .Errorf ("no proxy for connection migration" )
181176 }
182- return fmt .Errorf ("no proxy for connection migration" )
177+
178+ if err := c .proxy .renewUdpPath (timeout ); err != nil {
179+ return err
180+ }
181+
182+ if err := c .sendBusCommand ("alive" ); err != nil { // ping the server
183+ return fmt .Errorf ("ping server failed: %w" , err )
184+ }
185+
186+ return nil
183187}
184188
185189func (c * SshUdpClient ) newStream (cmd string ) (net.Conn , error ) {
@@ -305,14 +309,13 @@ func (c *SshUdpClient) IsClosed() bool {
305309 return c .closed .Load ()
306310}
307311
308- // SetDebugFunc set the debugging function
309- func (c * SshUdpClient ) SetDebugFunc (svrName string , debugFunc func (string , ... any )) {
310- c .serverName = svrName
312+ // SetDebugFunc sets the debugging function
313+ func (c * SshUdpClient ) SetDebugFunc (debugFunc func (string , ... any )) {
311314 clientDebug = debugFunc
312315 enableDebugLogging = true
313316}
314317
315- // SetWarningFunc set the warning function
318+ // SetWarningFunc sets the warning function
316319func (c * SshUdpClient ) SetWarningFunc (warningFunc func (string , ... any )) {
317320 clientWarning = warningFunc
318321 enableWarningLogging = true
@@ -380,6 +383,11 @@ func (c *SshUdpClient) ForwardUDPv1(addr string, timeout time.Duration) (string,
380383 return localAddr , nil
381384}
382385
386+ // SetKeepPendingInput sets whether to keep the pending input during reconnection.
387+ func (c * SshUdpClient ) SetKeepPendingInput (keep bool ) error {
388+ return c .sendBusMessage ("setting" , settingsMessage {KeepPendingInput : & keep })
389+ }
390+
383391func (c * SshUdpClient ) sendBusCommand (command string ) error {
384392 c .busMutex .Lock ()
385393 defer c .busMutex .Unlock ()
@@ -417,10 +425,12 @@ func (c *SshUdpClient) handleBusEvent() {
417425 c .handleChannelEvent ()
418426 case "alive" :
419427 if c .aliveCallback != nil {
420- c .aliveCallback (0 )
428+ go c .aliveCallback (0 )
421429 }
422430 case "alive2" :
423431 c .handleAliveEvent ()
432+ case "discard" :
433+ c .handleDiscardEvent ()
424434 default :
425435 if err := handleUnknownEvent (c .busStream , command ); err != nil {
426436 warning ("handle bus command [%s] failed: %v. You may need to upgrade tssh." , command , err )
@@ -435,11 +445,11 @@ func (c *SshUdpClient) handleQuitEvent() {
435445 warning ("recv quit message failed: %v" , err )
436446 return
437447 }
438- debug ("[client] [%s] quit due to %s" , c . serverName , quitMsg .Msg )
448+ debug ("[client] quit due to %s" , quitMsg .Msg )
439449 if c .quitCallback != nil {
440- go c .quitCallback (fmt . Sprintf ( "[%s] %s" , c . serverName , quitMsg .Msg ) )
450+ go c .quitCallback (quitMsg .Msg )
441451 } else {
442- warning ("[udp] quit due to [%s] %s" , c . serverName , quitMsg .Msg )
452+ warning ("quit due to %s" , quitMsg .Msg )
443453 }
444454}
445455
@@ -449,7 +459,7 @@ func (c *SshUdpClient) handleExitEvent() {
449459 warning ("recv exit message failed: %v" , err )
450460 return
451461 }
452- debug ("[client] [%s] session [%d] exiting with code: %d" , c . serverName , exitMsg .ID , exitMsg .ExitCode )
462+ debug ("[client] session [%d] exiting with code: %d" , exitMsg .ID , exitMsg .ExitCode )
453463
454464 c .sessionMutex .Lock ()
455465 defer c .sessionMutex .Unlock ()
@@ -471,7 +481,7 @@ func (c *SshUdpClient) handleDebugEvent() {
471481 warning ("recv debug message failed: %v" , err )
472482 return
473483 }
474- debug ("[server] [%s] %s" , c . serverName , dbgMsg .Msg )
484+ debug ("[server] %s" , dbgMsg .Msg )
475485}
476486
477487func (c * SshUdpClient ) handleErrorEvent () {
@@ -480,7 +490,7 @@ func (c *SshUdpClient) handleErrorEvent() {
480490 warning ("recv error message failed: %v" , err )
481491 return
482492 }
483- warning ("[udp] %s" , errMsg .Msg )
493+ warning ("%s" , errMsg .Msg )
484494}
485495
486496func (c * SshUdpClient ) handleChannelEvent () {
@@ -510,11 +520,22 @@ func (c *SshUdpClient) handleAliveEvent() {
510520 return
511521 }
512522 if c .aliveCallback != nil {
513- c .aliveCallback (aliveMsg .Time )
523+ go c .aliveCallback (aliveMsg .Time )
514524 }
515525 c .aliveNewVer = true
516526}
517527
528+ func (c * SshUdpClient ) handleDiscardEvent () {
529+ var discardMsg discardMessage
530+ if err := recvMessage (c .busStream , & discardMsg ); err != nil {
531+ warning ("recv discard message failed: %v" , err )
532+ return
533+ }
534+ if c .discardCallback != nil {
535+ go c .discardCallback (discardMsg .DiscardMarker , discardMsg .DiscardedInput )
536+ }
537+ }
538+
518539// SshUdpSession represents a connection to a remote command or shell
519540type SshUdpSession struct {
520541 id uint64
@@ -549,7 +570,7 @@ func (s *SshUdpSession) Close() error {
549570
550571 _ , err := doWithTimeout (func () (int , error ) {
551572 err := s .stream .Close ()
552- debug ("[client] [%s] close session completed" , s . client . serverName )
573+ debug ("[client] close session completed" )
553574 return 0 , err
554575 }, 100 * time .Millisecond )
555576 return err
@@ -625,14 +646,14 @@ func (s *SshUdpSession) startSession(msg *startMessage) error {
625646 go func () {
626647 _ , _ = io .Copy (s .stream , s .stdin )
627648 _ = s .stream .Close ()
628- debug ("[client] [%s] session [%d] stdin completed" , s . client . serverName , s .id )
649+ debug ("[client] session [%d] stdin completed" , s .id )
629650 }()
630651 }
631652 if s .stdout != nil {
632653 s .wg .Go (func () {
633654 _ , _ = io .Copy (s .stdout , s .stream )
634655 _ = s .stdout .Close ()
635- debug ("[client] [%s] session [%d] stdout completed" , s . client . serverName , s .id )
656+ debug ("[client] session [%d] stdout completed" , s .id )
636657 })
637658 }
638659 return nil
@@ -706,7 +727,7 @@ func (s *SshUdpSession) StderrPipe() (io.Reader, error) {
706727 s .wg .Go (func () {
707728 _ , _ = io .Copy (s .stderr , stream )
708729 _ = s .stderr .Close ()
709- debug ("[client] [%s] session [%d] stderr completed" , s . client . serverName , s .id )
730+ debug ("[client] session [%d] stderr completed" , s .id )
710731 })
711732 return reader , nil
712733}
0 commit comments