diff --git a/proxy/proxy.go b/proxy/proxy.go index 64faaa4..48616b7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -96,6 +96,7 @@ func NewProxyServer(config *ProxyConfig) (ProxyServer, error) { proxy := goproxy.NewProxyHttpServer() proxy.Logger = &goproxyLoggerWrapper{} + proxy.Tr = newUpstreamTransport(config) // Set verbose to true for verbose logging. // Logging is handled by our own logger which has log level controls. @@ -131,6 +132,24 @@ func NewProxyServer(config *ProxyConfig) (ProxyServer, error) { return ps, nil } +func newUpstreamTransport(config *ProxyConfig) *http.Transport { + dialer := &net.Dialer{ + Timeout: config.ConnectTimeout, + } + + // Keep transport behavior close to goproxy defaults and only harden TLS: + // enforce server certificate verification and require TLS 1.2+. + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + TLSHandshakeTimeout: config.ConnectTimeout, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: false, + }, + } +} + func (ps *proxyServer) Start() error { listener, err := net.Listen("tcp", ps.config.ListenAddr) if err != nil { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go new file mode 100644 index 0000000..64d0f45 --- /dev/null +++ b/proxy/proxy_test.go @@ -0,0 +1,53 @@ +package proxy + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewProxyServerSecuresUpstreamTLSConfig(t *testing.T) { + server, err := NewProxyServer(&ProxyConfig{ + ListenAddr: "127.0.0.1:0", + EnableMITM: false, + ConnectTimeout: 30 * time.Second, + RequestTimeout: 5 * time.Minute, + }) + assert.NoError(t, err) + + internalProxy, ok := server.(*proxyServer) + assert.True(t, ok) + assert.NotNil(t, internalProxy.proxy.Tr) + assert.NotNil(t, internalProxy.proxy.Tr.TLSClientConfig) + assert.False(t, internalProxy.proxy.Tr.TLSClientConfig.InsecureSkipVerify, "upstream TLS verification must stay enabled") + assert.GreaterOrEqual(t, internalProxy.proxy.Tr.TLSClientConfig.MinVersion, uint16(tls.VersionTLS12), "minimum TLS version should be 1.2+") +} + +func TestNewProxyServerRejectsUntrustedUpstreamCertByDefault(t *testing.T) { + target := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer target.Close() + + server, err := NewProxyServer(&ProxyConfig{ + ListenAddr: "127.0.0.1:0", + EnableMITM: false, + ConnectTimeout: 30 * time.Second, + RequestTimeout: 5 * time.Minute, + }) + assert.NoError(t, err) + + internalProxy, ok := server.(*proxyServer) + assert.True(t, ok) + + req, err := http.NewRequest(http.MethodGet, target.URL, nil) + assert.NoError(t, err) + + resp, err := internalProxy.proxy.Tr.RoundTrip(req) + assert.Error(t, err, "untrusted upstream certificate should fail verification") + assert.Nil(t, resp) +}