@@ -110,7 +110,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
110110 return tlsConfig , nil
111111}
112112
113- func (c * Client ) getClientTLSConfig (ctx context.Context , sign * api.SignResponse , pk crypto.PrivateKey , options []TLSOption ) (* tls.Config , * http.Transport , error ) {
113+ func (c * Client ) getClientTLSConfig (ctx context.Context , sign * api.SignResponse , pk crypto.PrivateKey , options []TLSOption ) (* tls.Config , http.RoundTripper , error ) {
114114 cert , err := TLSCertificate (sign , pk )
115115 if err != nil {
116116 return nil , nil , err
@@ -133,14 +133,18 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
133133
134134 tr := getDefaultTransport (tlsConfig )
135135 tr .DialTLSContext = c .buildDialTLSContext (tlsCtx )
136- renewer .RenewCertificate = getRenewFunc (tlsCtx , c , tr , pk ) //nolint:contextcheck // deeply nested context
136+
137+ // Add decorator if available, and use the resulting [http.RoundTripper]
138+ // going forward
139+ rt := decorateRoundTripper (tr , c .transportDecorator )
140+ renewer .RenewCertificate = getRenewFunc (tlsCtx , c , rt , pk ) //nolint:contextcheck // deeply nested context
137141
138142 // Update client transport
139- c .SetTransport (tr )
143+ c .SetTransport (rt )
140144
141145 // Start renewer
142146 renewer .RunContext (ctx )
143- return tlsConfig , tr , nil
147+ return tlsConfig , rt , nil
144148}
145149
146150// GetServerTLSConfig returns a tls.Config for server use configured with the
@@ -179,18 +183,23 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
179183 // Update renew function with transport
180184 tr := getDefaultTransport (tlsConfig )
181185 tr .DialTLSContext = c .buildDialTLSContext (tlsCtx )
182- renewer .RenewCertificate = getRenewFunc (tlsCtx , c , tr , pk ) //nolint:contextcheck // deeply nested context
186+
187+ // Add decorator if available, and use the resulting [http.RoundTripper]
188+ // going forward
189+ rt := decorateRoundTripper (tr , c .transportDecorator )
190+ renewer .RenewCertificate = getRenewFunc (tlsCtx , c , rt , pk ) //nolint:contextcheck // deeply nested context
183191
184192 // Update client transport
185- c .SetTransport (tr )
193+ c .SetTransport (rt )
186194
187195 // Start renewer
188196 renewer .RunContext (ctx )
189197 return tlsConfig , nil
190198}
191199
192- // Transport returns an http.Transport configured to use the client certificate from the sign response.
193- func (c * Client ) Transport (ctx context.Context , sign * api.SignResponse , pk crypto.PrivateKey , options ... TLSOption ) (* http.Transport , error ) {
200+ // Transport returns an [http.RoundTripper] configured to use the client
201+ // certificate from the sign response.
202+ func (c * Client ) Transport (ctx context.Context , sign * api.SignResponse , pk crypto.PrivateKey , options ... TLSOption ) (http.RoundTripper , error ) {
194203 _ , tr , err := c .getClientTLSConfig (ctx , sign , pk , options )
195204 if err != nil {
196205 return nil , err
@@ -365,7 +374,7 @@ func getPEM(i interface{}) ([]byte, error) {
365374 return pem .EncodeToMemory (block ), nil
366375}
367376
368- func getRenewFunc (ctx * TLSOptionCtx , client * Client , tr * http.Transport , pk crypto.PrivateKey ) RenewFunc {
377+ func getRenewFunc (ctx * TLSOptionCtx , client * Client , tr http.RoundTripper , pk crypto.PrivateKey ) RenewFunc {
369378 return func () (* tls.Certificate , error ) {
370379 // Close connections in keep-alive state
371380 defer client .CloseIdleConnections ()
0 commit comments