Skip to content

Commit a7ceb2d

Browse files
authored
Merge pull request #2533 from smallstep/mariano/transport-decorator
Add HTTP transport decorator
2 parents f088be2 + 63231db commit a7ceb2d

File tree

4 files changed

+94
-21
lines changed

4 files changed

+94
-21
lines changed

ca/client.go

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,13 @@ type RetryFunc func(code int) bool
153153
// ClientOption is the type of options passed to the Client constructor.
154154
type ClientOption func(o *clientOptions) error
155155

156+
// TransportDecorator is the type used to support customization of the HTTP
157+
// transport.
158+
type TransportDecorator func(http.RoundTripper) http.RoundTripper
159+
156160
type clientOptions struct {
157161
transport http.RoundTripper
162+
transportDecorator TransportDecorator
158163
timeout time.Duration
159164
rootSHA256 string
160165
rootFilename string
@@ -272,7 +277,8 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
272277
}
273278
}
274279

275-
return tr, nil
280+
// Wrap the transport using the decorator function if necessary
281+
return decorateRoundTripper(tr, o.transportDecorator), nil
276282
}
277283

278284
// WithTransport adds a custom transport to the Client. It will fail if a
@@ -287,6 +293,16 @@ func WithTransport(tr http.RoundTripper) ClientOption {
287293
}
288294
}
289295

296+
// WithTransportDecorator allows customization of the HTTP transport used by the
297+
// client. The provided function receives the configured [http.RoundTripper] and
298+
// can wrap it with additional functionality.
299+
func WithTransportDecorator(fn TransportDecorator) ClientOption {
300+
return func(o *clientOptions) error {
301+
o.transportDecorator = fn
302+
return nil
303+
}
304+
}
305+
290306
// WithInsecure adds a insecure transport that bypasses TLS verification.
291307
func WithInsecure() ClientOption {
292308
return func(o *clientOptions) error {
@@ -562,11 +578,12 @@ func WithProvisionerName(name string) ProvisionerOption {
562578

563579
// Client implements an HTTP client for the CA server.
564580
type Client struct {
565-
client *uaClient
566-
endpoint *url.URL
567-
retryFunc RetryFunc
568-
timeout time.Duration
569-
opts []ClientOption
581+
client *uaClient
582+
endpoint *url.URL
583+
retryFunc RetryFunc
584+
timeout time.Duration
585+
opts []ClientOption
586+
transportDecorator TransportDecorator
570587
}
571588

572589
// NewClient creates a new Client with the given endpoint and options.
@@ -587,11 +604,12 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
587604
}
588605

589606
return &Client{
590-
client: newClient(tr, o.timeout),
591-
endpoint: u,
592-
retryFunc: o.retryFunc,
593-
timeout: o.timeout,
594-
opts: opts,
607+
client: newClient(tr, o.timeout),
608+
endpoint: u,
609+
retryFunc: o.retryFunc,
610+
timeout: o.timeout,
611+
opts: opts,
612+
transportDecorator: o.transportDecorator,
595613
}, nil
596614
}
597615

@@ -1583,3 +1601,10 @@ func clientError(err error) error {
15831601
}
15841602
return fmt.Errorf("client request failed: %w", err)
15851603
}
1604+
1605+
func decorateRoundTripper(tr http.RoundTripper, td TransportDecorator) http.RoundTripper {
1606+
if td != nil {
1607+
return td(tr)
1608+
}
1609+
return tr
1610+
}

ca/client_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,43 @@ func TestClient_WithTimeout(t *testing.T) {
10451045
}
10461046
}
10471047

1048+
type decoratedRoundTripper func(*http.Request) (*http.Response, error)
1049+
1050+
func (rt decoratedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
1051+
return rt(req)
1052+
}
1053+
1054+
func TestClient_WithTransportDecorator(t *testing.T) {
1055+
var srv *httptest.Server
1056+
srv = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1057+
if strings.HasPrefix(r.RequestURI, "/root") {
1058+
render.JSONStatus(w, r, api.RootResponse{
1059+
RootPEM: api.NewCertificate(srv.Certificate()),
1060+
}, 200)
1061+
return
1062+
}
1063+
1064+
if s := r.Header.Get("X-Test-Header"); s != "" {
1065+
render.JSONStatus(w, r, api.HealthResponse{Status: s}, 200)
1066+
} else {
1067+
render.JSONStatus(w, r, api.HealthResponse{Status: "ok"}, 200)
1068+
}
1069+
}))
1070+
defer srv.Close()
1071+
1072+
fp := x509util.Fingerprint(srv.Certificate())
1073+
c, err := NewClient(srv.URL, WithRootSHA256(fp), WithTransportDecorator(func(tr http.RoundTripper) http.RoundTripper {
1074+
return decoratedRoundTripper(func(r *http.Request) (*http.Response, error) {
1075+
r.Header.Add("X-Test-Header", "some-data")
1076+
return tr.RoundTrip(r)
1077+
})
1078+
}))
1079+
require.NoError(t, err)
1080+
resp, err := c.Health()
1081+
require.NoError(t, err)
1082+
assert.Equal(t, "some-data", resp.Status)
1083+
}
1084+
10481085
func Test_enforceRequestID(t *testing.T) {
10491086
set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
10501087
set.Header.Set("X-Request-Id", "already-set")

ca/tls.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
110110
return tlsConfig, nil
111111
}
112112

113-
func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, *http.Transport, error) {
113+
func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, http.RoundTripper, error) {
114114
cert, err := TLSCertificate(sign, pk)
115115
if err != nil {
116116
return nil, nil, err
@@ -133,14 +133,18 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
133133

134134
tr := getDefaultTransport(tlsConfig)
135135
tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
136-
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context
136+
137+
// Add decorator if available, and use the resulting [http.RoundTripper]
138+
// going forward
139+
rt := decorateRoundTripper(tr, c.transportDecorator)
140+
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, rt, pk) //nolint:contextcheck // deeply nested context
137141

138142
// Update client transport
139-
c.SetTransport(tr)
143+
c.SetTransport(rt)
140144

141145
// Start renewer
142146
renewer.RunContext(ctx)
143-
return tlsConfig, tr, nil
147+
return tlsConfig, rt, nil
144148
}
145149

146150
// GetServerTLSConfig returns a tls.Config for server use configured with the
@@ -179,18 +183,23 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
179183
// Update renew function with transport
180184
tr := getDefaultTransport(tlsConfig)
181185
tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
182-
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context
186+
187+
// Add decorator if available, and use the resulting [http.RoundTripper]
188+
// going forward
189+
rt := decorateRoundTripper(tr, c.transportDecorator)
190+
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, rt, pk) //nolint:contextcheck // deeply nested context
183191

184192
// Update client transport
185-
c.SetTransport(tr)
193+
c.SetTransport(rt)
186194

187195
// Start renewer
188196
renewer.RunContext(ctx)
189197
return tlsConfig, nil
190198
}
191199

192-
// Transport returns an http.Transport configured to use the client certificate from the sign response.
193-
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
200+
// Transport returns an [http.RoundTripper] configured to use the client
201+
// certificate from the sign response.
202+
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (http.RoundTripper, error) {
194203
_, tr, err := c.getClientTLSConfig(ctx, sign, pk, options)
195204
if err != nil {
196205
return nil, err
@@ -365,7 +374,7 @@ func getPEM(i interface{}) ([]byte, error) {
365374
return pem.EncodeToMemory(block), nil
366375
}
367376

368-
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
377+
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr http.RoundTripper, pk crypto.PrivateKey) RenewFunc {
369378
return func() (*tls.Certificate, error) {
370379
// Close connections in keep-alive state
371380
defer client.CloseIdleConnections()

ca/tls_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,10 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
266266

267267
// Transport
268268
client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second)
269-
tr1, err := client.Transport(context.Background(), sr, pk)
269+
tr, err := client.Transport(context.Background(), sr, pk)
270270
require.NoError(t, err)
271+
tr1, ok := tr.(*http.Transport)
272+
require.True(t, ok)
271273

272274
// Transport with tlsConfig
273275
client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second)

0 commit comments

Comments
 (0)