Skip to content

Commit 5d377de

Browse files
stainless-app[bot]yjp20
authored andcommitted
feat: redact secrets from other authentication headers when using debug option
1 parent 103fbed commit 5d377de

File tree

3 files changed

+162
-25
lines changed

3 files changed

+162
-25
lines changed

internal/debugmiddleware/debug_middleware.go

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package debugmiddleware
22

33
import (
4+
"log"
45
"net/http"
56
"net/http/httputil"
67
"strings"
8+
"sync"
79
)
810

911
// For the time being these type definitions are duplicated here so that we can
@@ -15,14 +17,27 @@ type (
1517

1618
const redactedPlaceholder = "<REDACTED>"
1719

18-
// DebugMiddleware returns a middleware that logs HTTP requests and responses.
19-
//
20-
// logWriter is log.Default() under most circumstances, but made low level so we
21-
// can more easily inject a buffer to check in tests.
22-
func DebugMiddleware(logger interface{ Printf(string, ...any) }) Middleware {
20+
// Headers known to contain sensitive information like an API key.
21+
var sensitiveHeaders = []string{}
22+
23+
// RequestLogger is a middleware that logs HTTP requests and responses.
24+
type RequestLogger struct {
25+
logger interface{ Printf(string, ...any) } // field for testability; usually log.Default()
26+
sensitiveHeaders []string // field for testability; usually sensitiveHeaders
27+
}
28+
29+
// NewRequestLogger returns a new RequestLogger instance with default options.
30+
func NewRequestLogger() *RequestLogger {
31+
return &RequestLogger{
32+
logger: log.Default(),
33+
sensitiveHeaders: sensitiveHeaders,
34+
}
35+
}
36+
37+
func (m *RequestLogger) Middleware() Middleware {
2338
return func(req *http.Request, mn MiddlewareNext) (*http.Response, error) {
24-
if reqBytes, err := httputil.DumpRequest(redactRequest(req), true); err == nil {
25-
logger.Printf("Request Content:\n%s\n", reqBytes)
39+
if reqBytes, err := httputil.DumpRequest(m.redactRequest(req), true); err == nil {
40+
m.logger.Printf("Request Content:\n%s\n", reqBytes)
2641
}
2742

2843
resp, err := mn(req)
@@ -31,7 +46,7 @@ func DebugMiddleware(logger interface{ Printf(string, ...any) }) Middleware {
3146
}
3247

3348
if respBytes, err := httputil.DumpResponse(resp, true); err == nil {
34-
logger.Printf("Response Content:\n%s\n", respBytes)
49+
m.logger.Printf("Response Content:\n%s\n", respBytes)
3550
}
3651

3752
return resp, err
@@ -42,17 +57,40 @@ func DebugMiddleware(logger interface{ Printf(string, ...any) }) Middleware {
4257
// purposes. If redaction is necessary, the request is cloned before mutating
4358
// the original and that clone is returned. As a small optimization, the
4459
// original is request is returned unchanged if no redaction is necessary.
45-
func redactRequest(req *http.Request) *http.Request {
46-
if auth := req.Header.Get("Authorization"); auth != "" {
60+
func (m *RequestLogger) redactRequest(req *http.Request) *http.Request {
61+
cloneReq := sync.OnceFunc(func() {
4762
req = req.Clone(req.Context())
63+
})
64+
65+
// Notably, the clauses below are written so they can redact multiple
66+
// headers of the same name if necessary.
67+
if values := req.Header.Values("Authorization"); len(values) > 0 {
68+
cloneReq()
69+
req.Header.Del("Authorization")
70+
71+
for _, value := range values {
72+
// In case we're using something like a bearer token (e.g. `Bearer
73+
// <my_token>`), keep the `Bearer` part for more debugging
74+
// information.
75+
if authKind, _, ok := strings.Cut(value, " "); ok {
76+
req.Header.Add("Authorization", authKind+" "+redactedPlaceholder)
77+
} else {
78+
req.Header.Add("Authorization", redactedPlaceholder)
79+
}
80+
}
81+
}
82+
83+
for _, header := range m.sensitiveHeaders {
84+
values := req.Header.Values(header)
85+
if len(values) == 0 {
86+
continue
87+
}
88+
89+
cloneReq()
90+
req.Header.Del(header)
4891

49-
// In case we're using something like a bearer token (e.g. `Bearer
50-
// <my_token>`), keep the `Bearer` part for more debugging
51-
// information.
52-
if authKind, _, ok := strings.Cut(auth, " "); ok {
53-
req.Header.Set("Authorization", authKind+" "+redactedPlaceholder)
54-
} else {
55-
req.Header.Set("Authorization", redactedPlaceholder)
92+
for range values {
93+
req.Header.Add(header, redactedPlaceholder)
5694
}
5795
}
5896

internal/debugmiddleware/debug_middleware_test.go

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,21 @@ import (
55
"log"
66
"net/http"
77
"net/http/httptest"
8+
"reflect"
89
"strings"
910
"testing"
1011
)
1112

1213
func TestDebugMiddleware(t *testing.T) {
1314
t.Parallel()
1415

15-
setup := func() (Middleware, *bytes.Buffer) {
16-
var logBuf bytes.Buffer
17-
return DebugMiddleware(log.New(&logBuf, "", 0)), &logBuf
16+
setup := func() (*RequestLogger, *bytes.Buffer) {
17+
var (
18+
logBuf bytes.Buffer
19+
middleware = NewRequestLogger()
20+
)
21+
middleware.logger = log.New(&logBuf, "", 0)
22+
return middleware, &logBuf
1823
}
1924

2025
t.Run("DoesNotRedactMostHeaders", func(t *testing.T) {
@@ -28,7 +33,7 @@ func TestDebugMiddleware(t *testing.T) {
2833
req.Header.Set("User-Agent", stainlessUserAgent)
2934

3035
var nextMiddlewareRan bool
31-
middleware(req, func(req *http.Request) (*http.Response, error) {
36+
middleware.Middleware()(req, func(req *http.Request) (*http.Response, error) {
3237
nextMiddlewareRan = true
3338

3439
// The request sent down through middleware shouldn't be mutated.
@@ -59,7 +64,7 @@ func TestDebugMiddleware(t *testing.T) {
5964
req.Header.Set("Authorization", secretToken)
6065

6166
var nextMiddlewareRan bool
62-
middleware(req, func(req *http.Request) (*http.Response, error) {
67+
middleware.Middleware()(req, func(req *http.Request) (*http.Response, error) {
6368
nextMiddlewareRan = true
6469

6570
// The request sent down through middleware shouldn't be mutated.
@@ -88,7 +93,7 @@ func TestDebugMiddleware(t *testing.T) {
8893
req.Header.Set("Authorization", "Bearer "+secretToken)
8994

9095
var nextMiddlewareRan bool
91-
middleware(req, func(req *http.Request) (*http.Response, error) {
96+
middleware.Middleware()(req, func(req *http.Request) (*http.Response, error) {
9297
nextMiddlewareRan = true
9398

9499
return &http.Response{}, nil
@@ -102,4 +107,99 @@ func TestDebugMiddleware(t *testing.T) {
102107
t.Error("expected authorization header to be redacted")
103108
}
104109
})
110+
111+
t.Run("RedactsMultipleAuthorizationHeaders", func(t *testing.T) {
112+
t.Parallel()
113+
114+
middleware, logBuf := setup()
115+
116+
req := httptest.NewRequest("GET", "https://example.com", nil)
117+
req.Header.Add("Authorization", secretToken+"1")
118+
req.Header.Add("Authorization", secretToken+"2")
119+
120+
var nextMiddlewareRan bool
121+
middleware.Middleware()(req, func(req *http.Request) (*http.Response, error) {
122+
nextMiddlewareRan = true
123+
124+
// The request sent down through middleware shouldn't be mutated.
125+
if !reflect.DeepEqual(req.Header.Values("Authorization"), []string{secretToken + "1", secretToken + "2"}) {
126+
t.Errorf("expected original request to be unmodified")
127+
}
128+
129+
return &http.Response{}, nil
130+
})
131+
132+
if !nextMiddlewareRan {
133+
t.Error("expected next middleware to have been run")
134+
}
135+
136+
if strings.Count(logBuf.String(), "Authorization: "+redactedPlaceholder) != 2 {
137+
t.Error("expected exactly two redacted placeholders in authorization headers")
138+
}
139+
})
140+
141+
const customAPIKeyHeader = "X-My-Api-Key"
142+
143+
t.Run("RedactsSensitiveHeaders", func(t *testing.T) {
144+
t.Parallel()
145+
146+
middleware, logBuf := setup()
147+
148+
middleware.sensitiveHeaders = []string{customAPIKeyHeader}
149+
150+
req := httptest.NewRequest("GET", "https://example.com", nil)
151+
req.Header.Set(customAPIKeyHeader, secretToken)
152+
153+
var nextMiddlewareRan bool
154+
middleware.Middleware()(req, func(req *http.Request) (*http.Response, error) {
155+
nextMiddlewareRan = true
156+
157+
// The request sent down through middleware shouldn't be mutated.
158+
if req.Header.Get(customAPIKeyHeader) != secretToken {
159+
t.Error("expected original request to be unmodified")
160+
}
161+
162+
return &http.Response{}, nil
163+
})
164+
165+
if !nextMiddlewareRan {
166+
t.Error("expected next middleware to have been run")
167+
}
168+
169+
if !strings.Contains(logBuf.String(), customAPIKeyHeader+": "+redactedPlaceholder) {
170+
t.Errorf("expected %s header to be redacted", customAPIKeyHeader)
171+
}
172+
})
173+
174+
t.Run("RedactsMultipleSensitiveHeaders", func(t *testing.T) {
175+
t.Parallel()
176+
177+
middleware, logBuf := setup()
178+
179+
middleware.sensitiveHeaders = []string{customAPIKeyHeader}
180+
181+
req := httptest.NewRequest("GET", "https://example.com", nil)
182+
req.Header.Add(customAPIKeyHeader, secretToken+"1")
183+
req.Header.Add(customAPIKeyHeader, secretToken+"2")
184+
185+
var nextMiddlewareRan bool
186+
middleware.Middleware()(req, func(req *http.Request) (*http.Response, error) {
187+
nextMiddlewareRan = true
188+
189+
// The request sent down through middleware shouldn't be mutated.
190+
if !reflect.DeepEqual(req.Header.Values(customAPIKeyHeader), []string{secretToken + "1", secretToken + "2"}) {
191+
t.Error("expected original request to be unmodified")
192+
}
193+
194+
return &http.Response{}, nil
195+
})
196+
197+
if !nextMiddlewareRan {
198+
t.Error("expected next middleware to have been run")
199+
}
200+
201+
if strings.Count(logBuf.String(), customAPIKeyHeader+": "+redactedPlaceholder) != 2 {
202+
t.Errorf("expected %s header to be redacted", customAPIKeyHeader)
203+
}
204+
})
105205
}

pkg/cmd/flagoptions.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"encoding/json"
66
"io"
7-
"log"
87
"mime/multipart"
98
"os"
109

@@ -32,7 +31,7 @@ func flagOptions(
3231
) ([]option.RequestOption, error) {
3332
var options []option.RequestOption
3433
if cmd.Bool("debug") {
35-
options = append(options, option.WithMiddleware(debugmiddleware.DebugMiddleware(log.Default())))
34+
options = append(options, option.WithMiddleware(debugmiddleware.NewRequestLogger().Middleware()))
3635
}
3736

3837
queries := make(map[string]any)

0 commit comments

Comments
 (0)