Skip to content

Commit 0201247

Browse files
committed
improve robustness and diagnostics
1 parent eab4a26 commit 0201247

File tree

3 files changed

+45
-6
lines changed

3 files changed

+45
-6
lines changed

tsshd/client.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"fmt"
3030
"io"
3131
"net"
32+
"strings"
3233
"sync"
3334
"sync/atomic"
3435
"time"
@@ -135,7 +136,20 @@ func NewSshUdpClient(opts *UdpClientOptions) (*SshUdpClient, error) {
135136
if err != nil {
136137
return nil, err
137138
}
139+
beginTime := time.Now()
138140
if err := udpClient.networkProxy.renewTransportPath(opts.ProxyClient, opts.ConnectTimeout); err != nil {
141+
if opts.ConnectTimeout > 2*time.Second && time.Since(beginTime) > (opts.ConnectTimeout-time.Second) {
142+
net := "UDP"
143+
if opts.ServerInfo.ProxyMode == kProxyModeTCP {
144+
net = "TCP"
145+
}
146+
port := opts.TsshdAddr
147+
if pos := strings.LastIndex(opts.TsshdAddr, ":"); pos >= 0 {
148+
port = opts.TsshdAddr[pos+1:]
149+
}
150+
return nil, fmt.Errorf("%v\r\n%s", err, fmt.Sprintf(
151+
"\033[0;36mHint:\033[0m This may be caused by a firewall blocking the %s port (%s) that tsshd is listening on.", net, port))
152+
}
139153
return nil, err
140154
}
141155

@@ -348,7 +362,7 @@ func (c *SshUdpClient) DialUDP(network, addr string, timeout time.Duration) (Pac
348362
}
349363

350364
c.exitWG.Add(1)
351-
return &sshUdpPacketConn{conn, c}, nil
365+
return &sshUdpPacketConn{packetConn: conn, client: c}, nil
352366
}
353367

354368
// Listen requests the remote peer open a listening socket on addr
@@ -1224,9 +1238,13 @@ func (c *sshUdpChannel) Stderr() io.ReadWriter {
12241238
type sshUdpPacketConn struct {
12251239
*packetConn
12261240
client *SshUdpClient
1241+
closed atomic.Bool
12271242
}
12281243

12291244
func (c *sshUdpPacketConn) Close() error {
1245+
if !c.closed.CompareAndSwap(false, true) {
1246+
return nil
1247+
}
12301248
err := c.packetConn.Close()
12311249
c.client.exitWG.Done()
12321250
return err

tsshd/rekey.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ func (r *rotatingCrypto) startRekey() {
188188
return
189189
}
190190

191+
if r.client.IsClosed() {
192+
return
193+
}
194+
191195
err := func() error {
192196
curve := ecdh.P256()
193197
clientPriKey, err := curve.GenerateKey(crypto_rand.Reader)

tsshd/session.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ type sessionContext struct {
9090
rows int
9191
cmd *exec.Cmd
9292
pty *tsshdPty
93-
wg sync.WaitGroup
93+
outWG sync.WaitGroup
9494
stdin io.WriteCloser
9595
stdout io.ReadCloser
9696
stderr io.ReadCloser
@@ -452,11 +452,11 @@ func (c *sessionContext) forwardIO(stream Stream) {
452452
}
453453

454454
if c.stdout != nil {
455-
c.wg.Go(func() { c.forwardOutput("stdout", c.stdout, stream) })
455+
c.outWG.Go(func() { c.forwardOutput("stdout", c.stdout, stream) })
456456
}
457457

458458
if c.stderr != nil {
459-
c.wg.Go(func() {
459+
c.outWG.Go(func() {
460460
if stderr := getStderrStream(c.id); stderr != nil {
461461
c.forwardOutput("stderr", c.stderr, stderr.stream)
462462
stderr.Close()
@@ -475,17 +475,34 @@ func (c *sessionContext) Wait() {
475475
// windows pty only close the stdout in pty.Wait
476476
if runtime.GOOS == "windows" && c.pty != nil {
477477
_ = c.pty.Wait()
478-
c.wg.Wait()
478+
c.outWG.Wait()
479479
debug("session [%d] wait completed", c.id)
480480
return
481481
}
482482

483-
c.wg.Wait() // wait for the output done first to prevent cmd.Wait close output too early
483+
done := make(chan struct{})
484+
go func() {
485+
c.outWG.Wait() // wait for the output first to prevent cmd.Wait close output too early
486+
close(done)
487+
}()
488+
489+
select {
490+
case <-done:
491+
case <-time.After(time.Second):
492+
}
493+
484494
if c.pty != nil {
485495
_ = c.pty.Wait()
486496
} else {
487497
_ = c.cmd.Wait()
488498
}
499+
500+
select {
501+
case <-done:
502+
case <-time.After(3 * time.Second):
503+
warning("child process has exited, but output streams did not close in time")
504+
}
505+
489506
debug("session [%d] wait completed", c.id)
490507
}
491508

0 commit comments

Comments
 (0)