@@ -57,7 +57,7 @@ func (c *Common) NewSharedParams() SharedParams {
5757
5858func (c * Common ) BiDirectionalTransfer (runningContext context.Context , leftConnection , rightConnection net.Conn , byteBufferSize int , idleTimeout time.Duration , connectionType string , connectionNo uint64 ) {
5959 defer c .CloseConnection (leftConnection , rightConnection , connectionType , connectionNo )
60- done := make ( chan struct {}, 2 )
60+ transferContext , terminateTransfer := context . WithCancel ( runningContext )
6161 if err := leftConnection .SetDeadline (time .Now ().Add (idleTimeout )); err != nil {
6262 c .HandleSetDeadlineError (leftConnection , err )
6363 return
@@ -66,20 +66,17 @@ func (c *Common) BiDirectionalTransfer(runningContext context.Context, leftConne
6666 c .HandleSetDeadlineError (rightConnection , err )
6767 return
6868 }
69- go c .Transfer (runningContext , leftConnection , rightConnection , done , byteBufferSize , idleTimeout , connectionType , connectionNo )
70- go c .Transfer (runningContext , rightConnection , leftConnection , done , byteBufferSize , idleTimeout , connectionType , connectionNo )
71- <- done
72- <- done
69+ go c .Transfer (transferContext , terminateTransfer , leftConnection , rightConnection , byteBufferSize , idleTimeout , connectionType , connectionNo )
70+ go c .Transfer (transferContext , terminateTransfer , rightConnection , leftConnection , byteBufferSize , idleTimeout , connectionType , connectionNo )
71+ <- transferContext .Done ()
7372}
7473
75- func (c * Common ) Transfer (runningContext context.Context , sourceConnection , targetConnection net.Conn , done chan struct {}, bufferSize int , idleTimeout time.Duration , connectionType string , connectionNo uint64 ) {
76- defer func () {
77- done <- struct {}{}
78- }()
74+ func (c * Common ) Transfer (transferContext context.Context , terminateTransfer context.CancelFunc , sourceConnection , targetConnection net.Conn , bufferSize int , idleTimeout time.Duration , connectionType string , connectionNo uint64 ) {
75+ defer terminateTransfer ()
7976 buf := make ([]byte , bufferSize )
8077 for {
8178 select {
82- case <- runningContext .Done ():
79+ case <- transferContext .Done ():
8380 return
8481 default :
8582 if n , err := io .CopyBuffer (sourceConnection , targetConnection , buf ); err != nil {
@@ -95,6 +92,8 @@ func (c *Common) Transfer(runningContext context.Context, sourceConnection, targ
9592 return
9693 }
9794 c .logger .Printf ("%d bytes transferred for %s connection %d" , n , connectionType , connectionNo )
95+ } else {
96+ return
9897 }
9998 }
10099 }
@@ -117,10 +116,12 @@ func (c *Common) HandleListenerCloseError(err error) {
117116
118117func (c * Common ) handleConnectionError (err error , connectionType string , connectionNo uint64 ) {
119118 var netErr net.Error
120- if errors .As (err , & netErr ) && netErr .Timeout () {
121- c .logger .Printf ("%s connection %d timed out" , connectionType , connectionNo )
122- } else if err != io .EOF {
123- c .logger .Printf ("Failed to transfer data using %s connection %d: %v" , connectionType , connectionNo , err )
119+ if ! errors .Is (err , net .ErrClosed ) {
120+ if errors .As (err , & netErr ) && netErr .Timeout () {
121+ c .logger .Printf ("%s connection %d timed out" , connectionType , connectionNo )
122+ } else if err != io .EOF {
123+ c .logger .Printf ("Failed to transfer data using %s connection %d: %v" , connectionType , connectionNo , err )
124+ }
124125 }
125126}
126127
0 commit comments