@@ -33,18 +33,72 @@ type APIRegServer struct {
3333 metrics * metrics.Metrics
3434}
3535
36- // Get the first element of the X-Forwarded-For header if it is available, this
37- // will be the clients address if intermediate proxies follow X-Forwarded-For
38- // specification (as seen here: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For).
36+ var clientIPHeaderNames = []string {
37+ "X-Forwarded-For" ,
38+ // "X-Client-IP",
39+ // "True-Client-IP",
40+ }
41+
42+ // getRemoteAddr gets the last entry of the last instance of the X-Forwarded-For
43+ // header if it is available, this is our best guess at the clients address if
44+ // intermediate proxies follow X-Forwarded-For specification (as seen here:
45+ // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For).
3946// Otherwise return the remote address specified in the request.
4047//
41- // In the future this may need to handle True-Client-IP headers.
42- func getRemoteAddr (r * http.Request ) string {
43- if r .Header .Get ("X-Forwarded-For" ) != "" {
44- addrList := r .Header .Get ("X-Forwarded-For" )
45- return strings .Trim (strings .Split (addrList , "," )[0 ], " \t " )
48+ // In the future this may need to handle True-Client-IP headers, but in general
49+ // none of these are to be trusted -
50+ // https://adam-p.ca/blog/2022/03/x-forwarded-for/. If those are enabled in
51+ // clientIPHeaderNames ensure that the ordering checks them in order of most to
52+ // least trusted.
53+ func getRemoteAddr (r * http.Request ) net.IP {
54+
55+ // Default to the clients remote address if no identifying header is provided
56+ ip := parseIP (r .RemoteAddr )
57+
58+ // When there are multiple header names in clientIPHeaderNames,
59+ // the first valid match is preferred. clientIPHeaderNames should be
60+ // configured to use header names that are always provided by the CDN(s) and
61+ // not header names that may be passed through from clients.
62+ for _ , header := range clientIPHeaderNames {
63+
64+ // In the case where there are multiple headers,
65+ // request.Header.Get returns the first header, but we want the
66+ // last header; so use request.Header.Values and select the last
67+ // value. As per RFC 2616 section 4.2, a proxy must not change
68+ // the order of field values, which implies that it should append
69+ // values to the last header.
70+ values := r .Header .Values (header )
71+ if len (values ) > 0 {
72+ value := values [len (values )- 1 ]
73+
74+ // Some headers, such as X-Forwarded-For, are a comma-separated
75+ // list of IPs (each proxy in a chain). Select the last IP.
76+ IPs := strings .Split (value , "," )
77+ IP := IPs [len (IPs )- 1 ]
78+
79+ // Caddy appends an X-Forward-For from the client (potentially CDN)
80+ // We configure Caddy to trust the domain-fronted proxies,
81+ // which will give us a list of real_client_ip, cdn_ip.
82+ // In that case, r.RemoteAddr will be localhost, and we want
83+ // to skip the CDN IP in the list
84+ if len (IPs ) > 1 &&
85+ (ip .Equal (net .ParseIP ("127.0.0.1" )) ||
86+ ip .Equal (net .ParseIP ("::1" ))) {
87+ IP = IPs [len (IPs )- 2 ]
88+ }
89+
90+ // Remove optional whitespace surrounding the commas.
91+ IP = strings .TrimSpace (IP )
92+
93+ headerIP := net .ParseIP (IP )
94+ if headerIP != nil {
95+ ip = headerIP
96+ break
97+ }
98+ }
4699 }
47- return r .RemoteAddr
100+
101+ return ip
48102}
49103
50104func (s * APIRegServer ) getC2SFromReq (w http.ResponseWriter , r * http.Request ) (* pb.C2SWrapper , error ) {
@@ -81,11 +135,15 @@ func (s *APIRegServer) getC2SFromReq(w http.ResponseWriter, r *http.Request) (*p
81135func (s * APIRegServer ) register (w http.ResponseWriter , r * http.Request ) {
82136 s .metrics .Add ("api_requests_total" , 1 )
83137
84- requestIP := getRemoteAddr (r )
138+ clientAddr := getRemoteAddr (r )
139+ if clientAddr == nil {
140+ w .WriteHeader (http .StatusBadRequest )
141+ return
142+ }
85143
86144 logFields := log.Fields {"http_method" : r .Method , "content_length" : r .ContentLength , "registration_type" : "unidirectional" }
87145 if s .logClientIP {
88- logFields ["ip_address" ] = requestIP
146+ logFields ["ip_address" ] = clientAddr . String ()
89147 }
90148 reqLogger := s .logger .WithFields (logFields )
91149
@@ -99,7 +157,6 @@ func (s *APIRegServer) register(w http.ResponseWriter, r *http.Request) {
99157
100158 reqLogger = reqLogger .WithField ("reg_id" , hex .EncodeToString (payload .GetSharedSecret ()))
101159
102- clientAddr := parseIP (requestIP )
103160 var clientAddrBytes = make ([]byte , 16 )
104161 if clientAddr != nil {
105162 clientAddrBytes = []byte (clientAddr .To16 ())
@@ -122,11 +179,16 @@ func (s *APIRegServer) register(w http.ResponseWriter, r *http.Request) {
122179
123180func (s * APIRegServer ) registerBidirectional (w http.ResponseWriter , r * http.Request ) {
124181 s .metrics .Add ("bdapi_requests_total" , 1 )
125- requestIP := getRemoteAddr (r )
182+
183+ clientAddr := getRemoteAddr (r )
184+ if clientAddr == nil {
185+ w .WriteHeader (http .StatusBadRequest )
186+ return
187+ }
126188
127189 logFields := log.Fields {"http_method" : r .Method , "content_length" : r .ContentLength , "registration_type" : "bidirectional" }
128190 if s .logClientIP {
129- logFields ["ip_address" ] = requestIP
191+ logFields ["ip_address" ] = clientAddr . String ()
130192 }
131193 reqLogger := s .logger .WithFields (logFields )
132194
@@ -139,7 +201,6 @@ func (s *APIRegServer) registerBidirectional(w http.ResponseWriter, r *http.Requ
139201
140202 reqLogger = reqLogger .WithField ("reg_id" , hex .EncodeToString (payload .GetSharedSecret ()))
141203
142- clientAddr := parseIP (requestIP )
143204 var clientAddrBytes = make ([]byte , 16 )
144205 if clientAddr != nil {
145206 clientAddrBytes = []byte (clientAddr .To16 ())
@@ -214,7 +275,7 @@ func (s *APIRegServer) compareClientConfGen(genNum uint32) *pb.ClientConf {
214275
215276// parseIP attempts to parse the IP address of a request from string format wether
216277// it has a port attached to it or not. Returns nil if parse fails.
217- func parseIP (addrPort string ) * net.IP {
278+ func parseIP (addrPort string ) net.IP {
218279
219280 // by default format from r.RemoteAddr is host:port
220281 host , _ , err := net .SplitHostPort (addrPort )
@@ -224,13 +285,12 @@ func parseIP(addrPort string) *net.IP {
224285 if addr == nil {
225286 return nil
226287 }
227- return & addr
288+ return addr
228289 }
229290
230291 addr := net .ParseIP (host )
231292
232- return & addr
233-
293+ return addr
234294}
235295
236296func (s * APIRegServer ) NewClientConf (c * pb.ClientConf ) {
0 commit comments