Skip to content

Commit 762d2b0

Browse files
jmwampleewust
andauthored
X-Forwarded-For Handling (#150)
* update the handling of x-forwared-for headers in the API registrar * clarify test cases * update doc comment * fix * Ignore Caddy's inserted X-Forwarded-For header * Skip last IP even if we aren't equal, if the current IP was local host (Caddy) * allow v6 localhost * Only skip one * only skip if we have more to skip to, also. Ugh, conditions * Cleaner way to do the same logic Co-authored-by: Eric Wustrow <[email protected]>
1 parent 7b9312d commit 762d2b0

File tree

2 files changed

+114
-30
lines changed

2 files changed

+114
-30
lines changed

pkg/apiregserver/apiregserver.go

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,72 @@ type APIRegServer struct {
3333
metrics *metrics.Metrics
3434
}
3535

36-
// Get the first element of the X-Forwarded-For header if it is available, this
37-
// will be the clients address if intermediate proxies follow X-Forwarded-For
38-
// specification (as seen here: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For).
36+
var clientIPHeaderNames = []string{
37+
"X-Forwarded-For",
38+
// "X-Client-IP",
39+
// "True-Client-IP",
40+
}
41+
42+
// getRemoteAddr gets the last entry of the last instance of the X-Forwarded-For
43+
// header if it is available, this is our best guess at the clients address if
44+
// intermediate proxies follow X-Forwarded-For specification (as seen here:
45+
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For).
3946
// Otherwise return the remote address specified in the request.
4047
//
41-
// In the future this may need to handle True-Client-IP headers.
42-
func getRemoteAddr(r *http.Request) string {
43-
if r.Header.Get("X-Forwarded-For") != "" {
44-
addrList := r.Header.Get("X-Forwarded-For")
45-
return strings.Trim(strings.Split(addrList, ",")[0], " \t")
48+
// In the future this may need to handle True-Client-IP headers, but in general
49+
// none of these are to be trusted -
50+
// https://adam-p.ca/blog/2022/03/x-forwarded-for/. If those are enabled in
51+
// clientIPHeaderNames ensure that the ordering checks them in order of most to
52+
// least trusted.
53+
func getRemoteAddr(r *http.Request) net.IP {
54+
55+
// Default to the clients remote address if no identifying header is provided
56+
ip := parseIP(r.RemoteAddr)
57+
58+
// When there are multiple header names in clientIPHeaderNames,
59+
// the first valid match is preferred. clientIPHeaderNames should be
60+
// configured to use header names that are always provided by the CDN(s) and
61+
// not header names that may be passed through from clients.
62+
for _, header := range clientIPHeaderNames {
63+
64+
// In the case where there are multiple headers,
65+
// request.Header.Get returns the first header, but we want the
66+
// last header; so use request.Header.Values and select the last
67+
// value. As per RFC 2616 section 4.2, a proxy must not change
68+
// the order of field values, which implies that it should append
69+
// values to the last header.
70+
values := r.Header.Values(header)
71+
if len(values) > 0 {
72+
value := values[len(values)-1]
73+
74+
// Some headers, such as X-Forwarded-For, are a comma-separated
75+
// list of IPs (each proxy in a chain). Select the last IP.
76+
IPs := strings.Split(value, ",")
77+
IP := IPs[len(IPs)-1]
78+
79+
// Caddy appends an X-Forward-For from the client (potentially CDN)
80+
// We configure Caddy to trust the domain-fronted proxies,
81+
// which will give us a list of real_client_ip, cdn_ip.
82+
// In that case, r.RemoteAddr will be localhost, and we want
83+
// to skip the CDN IP in the list
84+
if len(IPs) > 1 &&
85+
(ip.Equal(net.ParseIP("127.0.0.1")) ||
86+
ip.Equal(net.ParseIP("::1"))) {
87+
IP = IPs[len(IPs)-2]
88+
}
89+
90+
// Remove optional whitespace surrounding the commas.
91+
IP = strings.TrimSpace(IP)
92+
93+
headerIP := net.ParseIP(IP)
94+
if headerIP != nil {
95+
ip = headerIP
96+
break
97+
}
98+
}
4699
}
47-
return r.RemoteAddr
100+
101+
return ip
48102
}
49103

50104
func (s *APIRegServer) getC2SFromReq(w http.ResponseWriter, r *http.Request) (*pb.C2SWrapper, error) {
@@ -81,11 +135,15 @@ func (s *APIRegServer) getC2SFromReq(w http.ResponseWriter, r *http.Request) (*p
81135
func (s *APIRegServer) register(w http.ResponseWriter, r *http.Request) {
82136
s.metrics.Add("api_requests_total", 1)
83137

84-
requestIP := getRemoteAddr(r)
138+
clientAddr := getRemoteAddr(r)
139+
if clientAddr == nil {
140+
w.WriteHeader(http.StatusBadRequest)
141+
return
142+
}
85143

86144
logFields := log.Fields{"http_method": r.Method, "content_length": r.ContentLength, "registration_type": "unidirectional"}
87145
if s.logClientIP {
88-
logFields["ip_address"] = requestIP
146+
logFields["ip_address"] = clientAddr.String()
89147
}
90148
reqLogger := s.logger.WithFields(logFields)
91149

@@ -99,7 +157,6 @@ func (s *APIRegServer) register(w http.ResponseWriter, r *http.Request) {
99157

100158
reqLogger = reqLogger.WithField("reg_id", hex.EncodeToString(payload.GetSharedSecret()))
101159

102-
clientAddr := parseIP(requestIP)
103160
var clientAddrBytes = make([]byte, 16)
104161
if clientAddr != nil {
105162
clientAddrBytes = []byte(clientAddr.To16())
@@ -122,11 +179,16 @@ func (s *APIRegServer) register(w http.ResponseWriter, r *http.Request) {
122179

123180
func (s *APIRegServer) registerBidirectional(w http.ResponseWriter, r *http.Request) {
124181
s.metrics.Add("bdapi_requests_total", 1)
125-
requestIP := getRemoteAddr(r)
182+
183+
clientAddr := getRemoteAddr(r)
184+
if clientAddr == nil {
185+
w.WriteHeader(http.StatusBadRequest)
186+
return
187+
}
126188

127189
logFields := log.Fields{"http_method": r.Method, "content_length": r.ContentLength, "registration_type": "bidirectional"}
128190
if s.logClientIP {
129-
logFields["ip_address"] = requestIP
191+
logFields["ip_address"] = clientAddr.String()
130192
}
131193
reqLogger := s.logger.WithFields(logFields)
132194

@@ -139,7 +201,6 @@ func (s *APIRegServer) registerBidirectional(w http.ResponseWriter, r *http.Requ
139201

140202
reqLogger = reqLogger.WithField("reg_id", hex.EncodeToString(payload.GetSharedSecret()))
141203

142-
clientAddr := parseIP(requestIP)
143204
var clientAddrBytes = make([]byte, 16)
144205
if clientAddr != nil {
145206
clientAddrBytes = []byte(clientAddr.To16())
@@ -214,7 +275,7 @@ func (s *APIRegServer) compareClientConfGen(genNum uint32) *pb.ClientConf {
214275

215276
// parseIP attempts to parse the IP address of a request from string format wether
216277
// it has a port attached to it or not. Returns nil if parse fails.
217-
func parseIP(addrPort string) *net.IP {
278+
func parseIP(addrPort string) net.IP {
218279

219280
// by default format from r.RemoteAddr is host:port
220281
host, _, err := net.SplitHostPort(addrPort)
@@ -224,13 +285,12 @@ func parseIP(addrPort string) *net.IP {
224285
if addr == nil {
225286
return nil
226287
}
227-
return &addr
288+
return addr
228289
}
229290

230291
addr := net.ParseIP(host)
231292

232-
return &addr
233-
293+
return addr
234294
}
235295

236296
func (s *APIRegServer) NewClientConf(c *pb.ClientConf) {

pkg/apiregserver/apiregserver_test.go

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,22 +218,46 @@ func BenchmarkRegistration(b *testing.B) {
218218
}
219219
}
220220

221+
var xff = "X-Forwarded-For"
222+
221223
func TestAPIGetClientAddr(t *testing.T) {
222224

223225
req, err := http.NewRequest("GET", "http://example.com", nil)
224226
require.Nil(t, err)
225227

226-
req.RemoteAddr = "10.0.0.0"
227-
require.Equal(t, "10.0.0.0", getRemoteAddr(req))
228-
229-
req.Header.Add("X-Forwarded-For", "192.168.1.1")
230-
require.Equal(t, "192.168.1.1", getRemoteAddr(req))
231-
232-
req.Header.Set("X-Forwarded-For", "127.0.0.1, 192.168.0.0")
233-
require.Equal(t, "127.0.0.1", getRemoteAddr(req))
234-
235-
req.Header.Set("X-Forwarded-For", "127.0.0.1,192.168.0.0")
236-
require.Equal(t, "127.0.0.1", getRemoteAddr(req))
228+
// If only the RemoteAddress is available we should use that.
229+
req.RemoteAddr = "10.0.0.0:80"
230+
ip := getRemoteAddr(req)
231+
require.Equal(t, net.ParseIP("10.0.0.0"), ip, "expected %s got %s", "10.0.0.0", ip)
232+
233+
// if an XFF address is available we should use that if it parses properly as an IP
234+
req.Header.Add(xff, "192.168.1.1")
235+
ip = getRemoteAddr(req)
236+
require.Equal(t, net.ParseIP("192.168.1.1"), ip, "expected %s got %s", "192.168.1.1", ip)
237+
238+
// if an XFF address is available, but does not parse as a valid IP we should return the
239+
// remote address.
240+
req.Header.Set(xff, "127.example.com")
241+
ip = getRemoteAddr(req)
242+
require.Equal(t, net.ParseIP("10.0.0.0"), ip, "expected %s got %s", "10.0.0.0", ip)
243+
244+
// If more than one IP is provided (i.e. multiple proxy hops) take the last one
245+
req.Header.Set(xff, "127.0.0.1, 192.168.0.0")
246+
ip = getRemoteAddr(req)
247+
require.Equal(t, net.ParseIP("192.168.0.0"), ip, "expected %s got %s", "192.168.0.0", ip)
248+
249+
req.Header.Set(xff, "127.0.0.1,192.168.0.0")
250+
ip = getRemoteAddr(req)
251+
require.Equal(t, net.ParseIP("192.168.0.0"), ip, "expected %s got %s", "192.168.0.0", ip)
252+
253+
// Add a second header X Ignore header with different value, we want to use the
254+
// last value from the last instance instance, i.e. the previous
255+
req.Header.Add(xff, "1.1.1.1,8.8.8.8")
256+
// for _, v := range req.Header.Values(xff) {
257+
// t.Log(xff, v)
258+
// }
259+
ip = getRemoteAddr(req)
260+
require.Equal(t, net.ParseIP("8.8.8.8"), ip, "expected %s got %s", "8.8.8.8", ip)
237261
}
238262

239263
func TestCorrectUnidirectionalAPI(t *testing.T) {

0 commit comments

Comments
 (0)