Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ require (
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect
gopkg.in/inconshreveable/log15.v2 v2.0.0-20200109203555-b30bc20e4fd1
)

replace github.com/imdario/mergo => dario.cat/mergo latest
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably is not a proper solution, but I have to admit that I have no clue about go's dependency management.

32 changes: 15 additions & 17 deletions tunnel/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
package tunnel

import (
"bytes"
"errors"
"fmt"
"html"
"io"
"io/ioutil"
"net/http"
"sync"

// imported per documentation - https://golang.org/pkg/net/http/pprof/
_ "net/http/pprof"
Expand Down Expand Up @@ -77,16 +76,14 @@ func wsHandler(t *WSTunnelServer, w http.ResponseWriter, r *http.Request) {
go func() {
rs.remoteName, rs.remoteWhois = ipAddrLookup(t.Log, rs.remoteAddr)
}()
// Set safety limits
ws.SetReadLimit(100 * 1024 * 1024)
// Start timeout handling
wsSetPingHandler(t, ws, rs)
// Create synchronization channel
ch := make(chan int, 2)
// Spawn goroutine to read responses
go wsReader(rs, ws, t.WSTimeout, ch)
go wsReader(rs, ws, t.WSTimeout, ch, &rs.readWG)
// Send requests
wsWriter(rs, ws, ch)
wsWriter(rs, ws, t.WSTimeout, ch)
}

func wsSetPingHandler(t *WSTunnelServer, ws *websocket.Conn, rs *remoteServer) {
Expand All @@ -111,7 +108,7 @@ func wsSetPingHandler(t *WSTunnelServer, ws *websocket.Conn, rs *remoteServer) {
}

// Pick requests off the RemoteServer queue and send them into the tunnel
func wsWriter(rs *remoteServer, ws *websocket.Conn, ch chan int) {
func wsWriter(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch chan int) {
var req *remoteRequest
var err error
for {
Expand All @@ -136,7 +133,7 @@ func wsWriter(rs *remoteServer, ws *websocket.Conn, ch chan int) {
continue
}
// write the request into the tunnel
ws.SetWriteDeadline(time.Now().Add(time.Minute))
ws.SetWriteDeadline(time.Now().Add(wsTimeout))
var w io.WriteCloser
w, err = ws.NextWriter(websocket.BinaryMessage)
// got an error, reply with a "hey, retry" to the request handler
Expand Down Expand Up @@ -170,11 +167,15 @@ func wsWriter(rs *remoteServer, ws *websocket.Conn, ch chan int) {
}

// Read responses from the tunnel and fulfill pending requests
func wsReader(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch chan int) {
func wsReader(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch chan int, readWG *sync.WaitGroup) {
var err error
logToken := cutToken(rs.token)
// continue reading until we get an error
for {
// wait if another response is being sent
readWG.Wait()
// increment the WaitGroup counter
readWG.Add(1)
ws.SetReadDeadline(time.Time{}) // no timeout, there's the ping-pong for that
// read a message from the tunnel
var t int
Expand All @@ -195,34 +196,31 @@ func wsReader(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch
if err != nil {
break
}
// read request itself, the size is limited by the SetReadLimit on the websocket
var buf []byte
buf, err = ioutil.ReadAll(r)
if err != nil {
break
}
rs.log.Info("WS RCV", "id", id, "ws", wsp(ws), "len", len(buf))
// try to match request
rs.requestSetMutex.Lock()
req := rs.requestSet[id]
rs.lastActivity = time.Now()
rs.requestSetMutex.Unlock()
// let's see...
if req != nil {
rb := responseBuffer{response: bytes.NewBuffer(buf)}
rb := responseBuffer{response: r}
// try to enqueue response
select {
case req.replyChan <- rb:
// great!
rs.log.Info("WS RCV enqueued response", "id", id, "ws", wsp(ws))
default:
readWG.Done()
rs.log.Info("WS RCV can't enqueue response", "id", id, "ws", wsp(ws))
}
} else {
readWG.Done()
rs.log.Info("%s #%d: WS RCV orphan response", "id", id, "ws", wsp(ws))
}
}
// print error message
if err != nil {
readWG.Done()
rs.log.Info("WS closing", "token", logToken, "err", err.Error(), "ws", wsp(ws))
}
// close up shop
Expand Down
2 changes: 1 addition & 1 deletion tunnel/wstuncli.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ func (wsc *WSConnection) writeResponseMessage(id int16, resp *http.Response) {
wsWriterMutex.Lock()
defer wsWriterMutex.Unlock()
// Write response into the tunnel
wsc.ws.SetWriteDeadline(time.Now().Add(time.Minute))
wsc.ws.SetWriteDeadline(time.Now().Add(wsc.tun.Timeout))
w, err := wsc.ws.NextWriter(websocket.BinaryMessage)
// got an error, reply with a "hey, retry" to the request handler
if err != nil {
Expand Down
11 changes: 7 additions & 4 deletions tunnel/wstunsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type token string

type responseBuffer struct {
err error
response *bytes.Buffer
response io.Reader
}

// A request for a remote server
Expand All @@ -77,6 +77,7 @@ type remoteServer struct {
requestSet map[int16]*remoteRequest // all requests in queue/flight indexed by ID
requestSetMutex sync.Mutex
log log15.Logger
readWG sync.WaitGroup
}

//WSTunnelServer a wstunnel server construct
Expand Down Expand Up @@ -372,7 +373,7 @@ func getResponse(t *WSTunnelServer, req *remoteRequest, w http.ResponseWriter, r
case resp := <-req.replyChan:
// if there's no error just respond
if resp.err == nil {
code := writeResponse(w, resp.response)
code := writeResponse(rs, w, resp.response)
req.log.Info("HTTP RET", "status", code)
return
}
Expand Down Expand Up @@ -498,8 +499,9 @@ var censoredHeaders = []string{
}

// Write an HTTP response from a byte buffer into a ResponseWriter
func writeResponse(w http.ResponseWriter, buf *bytes.Buffer) int {
resp, err := http.ReadResponse(bufio.NewReader(buf), nil)
func writeResponse(rs *remoteServer, w http.ResponseWriter, r io.Reader) int {
defer rs.readWG.Done()
resp, err := http.ReadResponse(bufio.NewReader(r), nil)
if err != nil {
log15.Info("WriteResponse: can't parse incoming response", "err", err)
w.WriteHeader(506)
Expand All @@ -512,6 +514,7 @@ func writeResponse(w http.ResponseWriter, buf *bytes.Buffer) int {
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
resp.Body.Close()
return resp.StatusCode
}

Expand Down