Skip to content

Commit 7865e50

Browse files
committed
Try terminating connections earlier.
1 parent 0b280cb commit 7865e50

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

pkg/domainproxy/common/common.go

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (c *Common) NewSharedParams() SharedParams {
5757

5858
func (c *Common) BiDirectionalTransfer(runningContext context.Context, leftConnection, rightConnection net.Conn, byteBufferSize int, idleTimeout time.Duration, connectionType string, connectionNo uint64) {
5959
defer c.CloseConnection(leftConnection, rightConnection, connectionType, connectionNo)
60-
done := make(chan struct{}, 2)
60+
transferContext, terminateTransfer := context.WithCancel(runningContext)
6161
if err := leftConnection.SetDeadline(time.Now().Add(idleTimeout)); err != nil {
6262
c.HandleSetDeadlineError(leftConnection, err)
6363
return
@@ -66,20 +66,17 @@ func (c *Common) BiDirectionalTransfer(runningContext context.Context, leftConne
6666
c.HandleSetDeadlineError(rightConnection, err)
6767
return
6868
}
69-
go c.Transfer(runningContext, leftConnection, rightConnection, done, byteBufferSize, idleTimeout, connectionType, connectionNo)
70-
go c.Transfer(runningContext, rightConnection, leftConnection, done, byteBufferSize, idleTimeout, connectionType, connectionNo)
71-
<-done
72-
<-done
69+
go c.Transfer(transferContext, terminateTransfer, leftConnection, rightConnection, byteBufferSize, idleTimeout, connectionType, connectionNo)
70+
go c.Transfer(transferContext, terminateTransfer, rightConnection, leftConnection, byteBufferSize, idleTimeout, connectionType, connectionNo)
71+
<-transferContext.Done()
7372
}
7473

75-
func (c *Common) Transfer(runningContext context.Context, sourceConnection, targetConnection net.Conn, done chan struct{}, bufferSize int, idleTimeout time.Duration, connectionType string, connectionNo uint64) {
76-
defer func() {
77-
done <- struct{}{}
78-
}()
74+
func (c *Common) Transfer(transferContext context.Context, terminateTransfer context.CancelFunc, sourceConnection, targetConnection net.Conn, bufferSize int, idleTimeout time.Duration, connectionType string, connectionNo uint64) {
75+
defer terminateTransfer()
7976
buf := make([]byte, bufferSize)
8077
for {
8178
select {
82-
case <-runningContext.Done():
79+
case <-transferContext.Done():
8380
return
8481
default:
8582
if n, err := io.CopyBuffer(sourceConnection, targetConnection, buf); err != nil {
@@ -95,6 +92,8 @@ func (c *Common) Transfer(runningContext context.Context, sourceConnection, targ
9592
return
9693
}
9794
c.logger.Printf("%d bytes transferred for %s connection %d", n, connectionType, connectionNo)
95+
} else {
96+
return
9897
}
9998
}
10099
}
@@ -117,10 +116,12 @@ func (c *Common) HandleListenerCloseError(err error) {
117116

118117
func (c *Common) handleConnectionError(err error, connectionType string, connectionNo uint64) {
119118
var netErr net.Error
120-
if errors.As(err, &netErr) && netErr.Timeout() {
121-
c.logger.Printf("%s connection %d timed out", connectionType, connectionNo)
122-
} else if err != io.EOF {
123-
c.logger.Printf("Failed to transfer data using %s connection %d: %v", connectionType, connectionNo, err)
119+
if !errors.Is(err, net.ErrClosed) {
120+
if errors.As(err, &netErr) && netErr.Timeout() {
121+
c.logger.Printf("%s connection %d timed out", connectionType, connectionNo)
122+
} else if err != io.EOF {
123+
c.logger.Printf("Failed to transfer data using %s connection %d: %v", connectionType, connectionNo, err)
124+
}
124125
}
125126
}
126127

0 commit comments

Comments
 (0)