Skip to content

Commit e007fd1

Browse files
authored
Retry proxy on connection refused error (#368)
1 parent 1db7517 commit e007fd1

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

http/proxy_server.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ package http
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"log"
89
"net"
910
"net/http"
1011
"regexp"
12+
"syscall"
1113
"time"
1214

1315
"github.com/superfly/litefs"
@@ -61,6 +63,8 @@ type ProxyServer struct {
6163

6264
// Time before cookie expires on client.
6365
CookieExpiry time.Duration
66+
67+
HTTPTransport *http.Transport
6468
}
6569

6670
// NewProxyServer returns a new instance of ProxyServer.
@@ -79,6 +83,19 @@ func NewProxyServer(store *litefs.Store) *ProxyServer {
7983
Handler: http.HandlerFunc(s.serveHTTP),
8084
}
8185

86+
s.HTTPTransport = &http.Transport{
87+
Proxy: http.ProxyFromEnvironment,
88+
DialContext: dialContextWithRetry(&net.Dialer{
89+
Timeout: 30 * time.Second,
90+
KeepAlive: 30 * time.Second,
91+
}),
92+
ForceAttemptHTTP2: true,
93+
MaxIdleConns: 100,
94+
IdleConnTimeout: 90 * time.Second,
95+
TLSHandshakeTimeout: 10 * time.Second,
96+
ExpectContinueTimeout: 1 * time.Second,
97+
}
98+
8299
return s
83100
}
84101

@@ -238,7 +255,7 @@ func (s *ProxyServer) proxyToTarget(w http.ResponseWriter, r *http.Request, pass
238255
r.URL.Scheme = "http"
239256
r.URL.Host = s.Target
240257

241-
resp, err := http.DefaultTransport.RoundTrip(r)
258+
resp, err := s.HTTPTransport.RoundTrip(r)
242259
if err != nil {
243260
http.Error(w, "Proxy error: "+err.Error(), http.StatusBadGateway)
244261
return
@@ -295,3 +312,26 @@ func (s *ProxyServer) logf(format string, v ...any) {
295312
log.Printf(format, v...)
296313
}
297314
}
315+
316+
// dialContextWithRetry returns a function that will retry
317+
func dialContextWithRetry(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
318+
return func(ctx context.Context, network, address string) (net.Conn, error) {
319+
timeout := time.NewTimer(dialer.Timeout)
320+
defer timeout.Stop()
321+
322+
for {
323+
conn, err := dialer.DialContext(ctx, network, address)
324+
if !errors.Is(err, syscall.ECONNREFUSED) {
325+
return conn, err
326+
}
327+
328+
select {
329+
case <-ctx.Done():
330+
return nil, context.Cause(ctx)
331+
case <-timeout.C:
332+
return nil, err
333+
case <-time.After(100 * time.Millisecond):
334+
}
335+
}
336+
}
337+
}

0 commit comments

Comments
 (0)