Skip to content
37 changes: 18 additions & 19 deletions lib/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,9 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
for {
// For all but the first request, create the next
// request hop and replace req.
loc := req.URL.String()
if len(reqs) > 0 {
loc := resp.Header.Get("Location")
loc = resp.Header.Get("Location")
if loc == "" {
return nil, uerr(fmt.Errorf("%d response missing Location header", resp.StatusCode))
}
Expand Down Expand Up @@ -571,14 +572,6 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" {
req.Header.Set("Referer", ref)
}
err = c.checkRedirect(req, resp, reqs)

// Sentinel error to let users select the
// previous response, without closing its
// body. See Issue 10069.
if err == ErrUseLastResponse {
return resp, nil
}

// Close the previous response's body. But
// read at least some of the body so if it's
Expand All @@ -590,16 +583,6 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize)
}
resp.Body.Close()

if err != nil {
// Special case for Go 1 compatibility: return both the response
// and an error if the CheckRedirect function failed.
// See https://golang.org/issue/3795
// The resp.Body has already been closed.
ue := uerr(err)
ue.(*url.Error).URL = loc
return resp, ue
}
}

reqs = append(reqs, req)
Expand All @@ -614,6 +597,22 @@ func (c *Client) Do(req *Request) (resp *Response, err error) {
}
return nil, uerr(err)
}
err = c.checkRedirect(req, resp, reqs)

// Sentinel error to let users select the
// previous response, without closing its
// body. See Issue 10069.
if err == ErrUseLastResponse {
return resp, nil
}
if err != nil {
// Special case for Go 1 compatibility: return both the response
// and an error if the CheckRedirect function failed.
// See https://golang.org/issue/3795
ue := uerr(err)
ue.(*url.Error).URL = loc
return resp, err
}

var shouldRedirect bool
redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0])
Expand Down
14 changes: 10 additions & 4 deletions modules/http/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
// ErrTooManyRedirects is returned when the number of HTTP redirects exceeds
// MaxRedirects.
ErrTooManyRedirects = errors.New("too many redirects")
ErrDoNotRedirect = errors.New("no redirects configured")
)

// Flags holds the command-line configuration for the HTTP scan module.
Expand Down Expand Up @@ -326,6 +327,13 @@ func redirectsToLocalhost(host string) bool {
// the redirectToLocalhost and MaxRedirects config
func (scan *scan) getCheckRedirect() func(*http.Request, *http.Response, []*http.Request) error {
return func(req *http.Request, res *http.Response, via []*http.Request) error {
if scan.scanner.config.MaxRedirects == 0 {
return ErrDoNotRedirect
}
//len-1 because otherwise we'll return a failure on 1 redirect when we specify only 1 redirect. I.e. we are 0
if len(via)-1 > scan.scanner.config.MaxRedirects {
return ErrTooManyRedirects
}
if !scan.scanner.config.FollowLocalhostRedirects && redirectsToLocalhost(req.URL.Hostname()) {
return ErrRedirLocalhost
}
Expand Down Expand Up @@ -353,10 +361,6 @@ func (scan *scan) getCheckRedirect() func(*http.Request, *http.Response, []*http
}
}

if len(via) > scan.scanner.config.MaxRedirects {
return ErrTooManyRedirects
}

return nil
}
}
Expand Down Expand Up @@ -496,6 +500,8 @@ func (scan *scan) Grab() *zgrab2.ScanError {
}
if err != nil {
switch err {
case ErrDoNotRedirect:
break
case ErrRedirLocalhost:
break
case ErrTooManyRedirects:
Expand Down
17 changes: 12 additions & 5 deletions modules/ipp/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"encoding/binary"
"errors"
"fmt"

"io"
"mime"
"net"
Expand Down Expand Up @@ -39,6 +38,9 @@ var (
// MaxRedirects.
ErrTooManyRedirects = errors.New("too many redirects")

// ErrDoNotRedirect is returned when the scanner is configured not to follow redirects
ErrDoNotRedirect = errors.New("no redirects configured")

// TODO: Explain this error
ErrVersionNotSupported = errors.New("IPP version not supported")

Expand Down Expand Up @@ -510,6 +512,8 @@ func sendIPPRequest(scan *scan, body *bytes.Buffer) (*http.Response, *zgrab2.Sca
break
case ErrTooManyRedirects:
return resp, zgrab2.NewScanError(zgrab2.SCAN_APPLICATION_ERROR, err)
case ErrDoNotRedirect:
break
default:
return resp, zgrab2.DetectScanError(err)
}
Expand Down Expand Up @@ -648,6 +652,13 @@ func redirectsToLocalhost(host string) bool {
// Taken from zgrab/zlib/grabber.go -- get a CheckRedirect callback that uses redirectToLocalhost and MaxRedirects config
func (scan *scan) getCheckRedirect(scanner *Scanner) func(*http.Request, *http.Response, []*http.Request) error {
return func(req *http.Request, res *http.Response, via []*http.Request) error {
if scanner.config.MaxRedirects == 0 {
return ErrDoNotRedirect
}
//len-1 because otherwise we'll return a failure on 1 redirect when we specify only 1 redirect. I.e. we are 0
if len(via)-1 > scanner.config.MaxRedirects {
return ErrTooManyRedirects
}
if !scanner.config.FollowLocalhostRedirects && redirectsToLocalhost(req.URL.Hostname()) {
return ErrRedirLocalhost
}
Expand All @@ -656,10 +667,6 @@ func (scan *scan) getCheckRedirect(scanner *Scanner) func(*http.Request, *http.R
return zgrab2.NewScanError(zgrab2.SCAN_UNKNOWN_ERROR, fmt.Errorf("could not store body: %v", err))
}

if len(via) > scanner.config.MaxRedirects {
return ErrTooManyRedirects
}

return nil
}
}
Expand Down
Loading