diff --git a/.gitignore b/.gitignore index 20bd640..26fa58b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /.idea/ /deps /.gobin/ +/vendor/ diff --git a/client/side_channel_creds.go b/client/side_channel_creds.go index 689fde1..fcefba6 100644 --- a/client/side_channel_creds.go +++ b/client/side_channel_creds.go @@ -15,8 +15,12 @@ package client import ( + "bufio" "context" + "fmt" "net" + "net/http" + "net/url" "sync" "google.golang.org/grpc/credentials" @@ -47,7 +51,24 @@ func (c *sideChannelCreds) ClientHandshake(ctx context.Context, authority string return rawConn, c.authInfo, nil } - sideChannelConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", c.endpoint) + // check if c.endpoint is reached via proxy + destReq, err := http.NewRequest("GET", "http://"+c.endpoint, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to determine proxy URL for %s: %w", c.endpoint, err) + } + proxyURL, err := http.ProxyFromEnvironment(destReq) + if err != nil { + return nil, nil, fmt.Errorf("failed to determine proxy URL for %s: %w", c.endpoint, err) + } + + var sideChannelConn net.Conn + if proxyURL != nil { + // net dial via HTTP CONNECT tunnel if using proxy + sideChannelConn, err = dialViaCONNECT(ctx, c.endpoint, proxyURL) + } else { + sideChannelConn, err = new(net.Dialer).DialContext(ctx, "tcp", c.endpoint) + } + if err != nil { return nil, nil, err } @@ -61,3 +82,28 @@ func (c *sideChannelCreds) ClientHandshake(ctx context.Context, authority string c.authInfo = authInfo return rawConn, authInfo, nil } + +// dialViaCONNECT tunnels a tcp connection to addr through proxy using HTTP CONNECT +func dialViaCONNECT(ctx context.Context, addr string, proxy *url.URL) (net.Conn, error) { + proxyAddr := proxy.Host + if proxy.Port() == "" { + proxyAddr = net.JoinHostPort(proxyAddr, "80") + } + conn, err := new(net.Dialer).DialContext(ctx, "tcp", proxyAddr) + if err != nil { + return nil, fmt.Errorf("failed to dial proxy %s: %w", proxyAddr, err) + } + fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", addr, proxy.Hostname()) + rr := bufio.NewReader(conn) + res, err := http.ReadResponse(rr, nil) + if err != nil { + return nil, fmt.Errorf("failed to read response from HTTP CONNECT to %s via proxy %s: %w", addr, proxyAddr, err) + } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to dial %s via %s. response status: %v", addr, proxyAddr, res.Status) + } + if rr.Buffered() > 0 { + return nil, fmt.Errorf("CONNECT response from %s resulted in %d bytes of unexpected data", proxyAddr, rr.Buffered()) + } + return conn, nil +} diff --git a/client/ws_proxy.go b/client/ws_proxy.go index 2c7d357..5bedb1e 100644 --- a/client/ws_proxy.go +++ b/client/ws_proxy.go @@ -299,6 +299,7 @@ func createClientWSProxy(endpoint string, tlsClientConf *tls.Config) (*http.Serv httpClient: &http.Client{ Transport: &http.Transport{ TLSClientConfig: tlsClientConf, + Proxy: http.ProxyFromEnvironment, }, }, }