diff --git a/client.go b/client.go index b5493e9271..32dcb55dd1 100644 --- a/client.go +++ b/client.go @@ -210,6 +210,12 @@ type Client struct { // Default TLS config is used if not set. TLSConfig *tls.Config + // Maximum number of greedy connections per each host which may be established + // before reusing idle connections in the pool. + // + // By default, MaxGreedyConnsPerHost is 0. + MaxGreedyConnsPerHost int + // Maximum number of connections per each host which may be established. // // DefaultMaxConnsPerHost is used if not set. @@ -517,6 +523,7 @@ func (c *Client) Do(req *Request, resp *Response) error { DialDualStack: c.DialDualStack, IsTLS: isTLS, TLSConfig: c.TLSConfig, + MaxGreedyConns: c.MaxGreedyConnsPerHost, MaxConns: c.MaxConnsPerHost, MaxIdleConnDuration: c.MaxIdleConnDuration, MaxConnDuration: c.MaxConnDuration, @@ -726,6 +733,15 @@ type HostClient struct { // Optional TLS config. TLSConfig *tls.Config + // Maximum number of greedy connections which may be established to all hosts + // listed in Addr. + // If it is set, the HostClient.acquireConn with create new connection + // other than reusing idle connections in the pool until the MaxGreedyConns + // + // You can change this value while the HostClient is being used + // with HostClient.SetMaxGreedyConns(value) + MaxGreedyConns int + // Maximum number of connections which may be established to all hosts // listed in Addr. // @@ -1455,6 +1471,13 @@ func (e *timeoutError) Timeout() bool { // ErrTimeout is returned from timed out calls. var ErrTimeout = &timeoutError{} +// SetMaxGreedyConns sets up the maximum greedy number of connections which may be established to all hosts listed in Addr. +func (c *HostClient) SetMaxGreedyConns(newGreedyMaxConns int) { + c.connsLock.Lock() + c.MaxGreedyConns = newGreedyMaxConns + c.connsLock.Unlock() +} + // SetMaxConns sets up the maximum number of connections which may be established to all hosts listed in Addr. func (c *HostClient) SetMaxConns(newMaxConns int) { c.connsLock.Lock() @@ -1469,11 +1492,15 @@ func (c *HostClient) acquireConn(reqTimeout time.Duration, connectionClose bool) var n int c.connsLock.Lock() n = len(c.conns) - if n == 0 { + + shouldBeGreedy := c.MaxGreedyConns > 0 && c.connsCount < c.MaxGreedyConns + + if n == 0 || shouldBeGreedy { maxConns := c.MaxConns if maxConns <= 0 { maxConns = DefaultMaxConnsPerHost } + if c.connsCount < maxConns { c.connsCount++ createConn = true @@ -1482,7 +1509,9 @@ func (c *HostClient) acquireConn(reqTimeout time.Duration, connectionClose bool) c.connsCleanerRun = true } } - } else { + } + + if n > 0 && !createConn { switch c.ConnPoolStrategy { case LIFO: n-- @@ -2929,6 +2958,12 @@ func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (ret if err == nil { err = bw.Flush() } + + + if req.ConnAcquiredCallback != nil { + req.ConnAcquiredCallback(conn) + } + hc.releaseWriter(bw) // Return ErrTimeout on any timeout. diff --git a/client_test.go b/client_test.go index 48429657cc..ec7b617551 100644 --- a/client_test.go +++ b/client_test.go @@ -2472,6 +2472,48 @@ func TestClientHTTPSConcurrent(t *testing.T) { wg.Wait() } +func TestClientMaxGreedyConnsPerHost(t *testing.T) { + t.Parallel() + + sHTTP := startEchoServer(t, "tcp", "127.0.0.1:") + defer sHTTP.Stop() + + testClientMaxGreedyConnsPerHost(t, sHTTP, 0, 10, 1) + testClientMaxGreedyConnsPerHost(t, sHTTP, 5, 10, 5) + testClientMaxGreedyConnsPerHost(t, sHTTP, 10, 10, 10) + testClientMaxGreedyConnsPerHost(t, sHTTP, 15, 10, 10) +} + +func testClientMaxGreedyConnsPerHost(t *testing.T, sHTTP *testEchoServer, maxGreedyConnsPerHost, maxConnsPerHost, expectConnsCount int) { + c := &Client{ + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + MaxGreedyConnsPerHost: maxGreedyConnsPerHost, + MaxConnsPerHost: maxConnsPerHost, + } + defer c.CloseIdleConnections() + + addr := "http://" + sHTTP.Addr() + for i := 0; i < c.MaxConnsPerHost; i++ { + testClientGet(t, c, addr, 1) + } + if len(c.m) != 1 { + t.Errorf("unexpected host map %d. Expecting 1", len(c.m)) + } + + var hc *HostClient + for _, v := range c.m { + hc = v + } + + cc := hc.ConnsCount() + + if cc != expectConnsCount { + t.Errorf("unexpected ConnsCount %d. Expecting %d", cc, expectConnsCount) + } +} + func TestClientManyServers(t *testing.T) { t.Parallel() diff --git a/http.go b/http.go index 5dd4e645f3..e7d422f8de 100644 --- a/http.go +++ b/http.go @@ -77,6 +77,12 @@ type Request struct { // By default redirect path values are normalized, i.e. // extra slashes are removed, special characters are encoded. DisableRedirectPathNormalizing bool + + // UsingProxy 是否使用代理 + UsingProxy bool + + // ConnAcquiredCallback 连接获取后回调,用于获取连接信息(例如 *tls.Conn 中的证书信息等) + ConnAcquiredCallback func(conn net.Conn) } // Response represents HTTP response. @@ -1582,7 +1588,12 @@ func (req *Request) Write(w *bufio.Writer) error { } else if !req.UseHostHeader { req.Header.SetHostBytes(host) } - req.Header.SetRequestURIBytes(uri.RequestURI()) + + if req.UsingProxy { + req.Header.SetRequestURIBytes(uri.FullURI()) + } else { + req.Header.SetRequestURIBytes(uri.RequestURI()) + } if len(uri.username) > 0 { // RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key diff --git a/tcpdialer.go b/tcpdialer.go index e8430cb9c8..0fc3861701 100644 --- a/tcpdialer.go +++ b/tcpdialer.go @@ -133,10 +133,10 @@ type TCPDialer struct { // Changes made after the first Dial will not affect anything. Concurrency int - // LocalAddr is the local address to use when dialing an + // LocalAddrFunc is the local address to use when dialing an // address. // If nil, a local address is automatically chosen. - LocalAddr *net.TCPAddr + LocalAddrFunc func() *net.TCPAddr // This may be used to override DNS resolving policy, like this: // var dialer = &fasthttp.TCPDialer{ @@ -339,8 +339,8 @@ func (d *TCPDialer) tryDial( } dialer := net.Dialer{} - if d.LocalAddr != nil { - dialer.LocalAddr = d.LocalAddr + if d.LocalAddrFunc != nil { + dialer.LocalAddr = d.LocalAddrFunc() } ctx, cancelCtx := context.WithDeadline(context.Background(), deadline)