From eecc7b7c4cc626c624bc55e8db0330dce7dd1a6d Mon Sep 17 00:00:00 2001 From: Nick Stenning Date: Sat, 6 Apr 2024 14:26:48 +0200 Subject: [PATCH 1/2] Implement RFC9530 Content-Digest as an http.RoundTripper This implements an http.RoundTripper which transparently calculates RFC9530-compatible Content-Digest headers for outgoing requests. This header will be needed if we in future wish to sign request bodies with RFC9421. --- http/digest/transport.go | 55 +++++++++++++++++ http/digest/transport_test.go | 111 ++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+) create mode 100644 http/digest/transport.go create mode 100644 http/digest/transport_test.go diff --git a/http/digest/transport.go b/http/digest/transport.go new file mode 100644 index 0000000..bdc4128 --- /dev/null +++ b/http/digest/transport.go @@ -0,0 +1,55 @@ +// Package digest implements support for HTTP Content-Digest headers as +// described in [RFC 9530]. Currently it only supports adding SHA-512 digests to +// outgoing requests via the Transport type. +// +// [RFC 9530]: https://www.rfc-editor.org/rfc/rfc9530.html +package digest + +import ( + "bytes" + "crypto/sha512" + "encoding/base64" + "io" + "net/http" +) + +// Transport is an implementation of http.RoundTripper that automatically adds +// an RFC 9530 Content-Digest header to outgoing requests. +// +// Note: This transport will necessarily buffer the request body in memory in +// order to calculate the digest. +type Transport struct { + http.RoundTripper +} + +func NewTransport(t http.RoundTripper) *Transport { + return &Transport{ + RoundTripper: t, + } +} + +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + h := sha512.New() + + // RoundTrip must not modify the original request. + req = req.Clone(req.Context()) + + if req.Body != nil { + // RoundTrip must close the request body even in the event of an error. + defer req.Body.Close() + + body := io.TeeReader(req.Body, h) + + var buf bytes.Buffer + if _, err := io.Copy(&buf, body); err != nil { + return nil, err + } + + req.Body = io.NopCloser(&buf) + } + + digest := base64.StdEncoding.EncodeToString(h.Sum(nil)) + req.Header.Set("Content-Digest", "sha-512=:"+digest+":") + + return t.RoundTripper.RoundTrip(req) +} diff --git a/http/digest/transport_test.go b/http/digest/transport_test.go new file mode 100644 index 0000000..266ffe3 --- /dev/null +++ b/http/digest/transport_test.go @@ -0,0 +1,111 @@ +package digest_test + +import ( + "bytes" + "io" + "math/rand/v2" + "net/http" + "net/http/httptest" + "testing" + + "github.com/replicate/go/http/digest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func serverExpectingDigest(t *testing.T, digest string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expected := "sha-512=:" + digest + ":" + received := r.Header.Get("Content-Digest") + + assert.Equal(t, expected, received) + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"OK"}`)) + })) +} + +// Generate a predictable seeded random payload of a given size +func generatePayload(s1, s2 uint64, size int) io.ReadCloser { + r := rand.New(rand.NewPCG(s1, s2)) + data := make([]byte, size) + for i := 0; i < size; i++ { + data[i] = byte(r.IntN(256)) + } + return io.NopCloser(bytes.NewReader(data)) +} + +func TestTransport(t *testing.T) { + testcases := []struct { + Name string + Body io.ReadCloser + Digest string + }{ + { + Name: "nil body", + Body: nil, + Digest: "z4PhNX7vuL3xVChQ1m2AB9Yg5AULVxXcg/SpIdNs6c5H0NE8XYXysP+DGNKHfuwvY7kxvUdBeoGlODJ6+SfaPg==", + }, + { + Name: "empty body", + Body: io.NopCloser(bytes.NewReader([]byte{})), + Digest: "z4PhNX7vuL3xVChQ1m2AB9Yg5AULVxXcg/SpIdNs6c5H0NE8XYXysP+DGNKHfuwvY7kxvUdBeoGlODJ6+SfaPg==", + }, + { + Name: "hello world", + Body: io.NopCloser(bytes.NewReader([]byte("hello world"))), + Digest: "MJ7MSJwS1utMxA9QyQLytNDtd+5RGnx6m808qG1M2G+YndNbxf9JlnDaNCVbRbDP2DDoH2Bdz33FVC6TrpzXbw==", + }, + { + Name: "large body (128KB)", + Body: generatePayload(42, 42, 128*1024), + Digest: "fV+7qAxDBpKPaXsFZogCBpSROb5F+j/5kvIIPWMXQUcyiOiL/4YCbo9HwybsuD1rYQ7sBAEW4HnlHrrkSYEI6w==", + }, + } + + client := &http.Client{ + Transport: digest.NewTransport(http.DefaultTransport), + } + + for _, tc := range testcases { + t.Run(tc.Name, func(t *testing.T) { + server := serverExpectingDigest(t, tc.Digest) + defer server.Close() + + req, err := http.NewRequest("GET", server.URL, tc.Body) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + }) + } +} + +type nopTransport struct{} + +func (tr *nopTransport) RoundTrip(_ *http.Request) (*http.Response, error) { + return &http.Response{}, nil +} + +func BenchmarkTransport(b *testing.B) { + n := 128 * 1024 + + b.ReportAllocs() + b.SetBytes(int64(n)) + + transport := digest.NewTransport(&nopTransport{}) + + requests := make([]*http.Request, b.N) + for i := 0; i < b.N; i++ { + req, err := http.NewRequest("GET", "http://example.com", generatePayload(456, uint64(i), n)) + require.NoError(b, err) + requests[i] = req + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = transport.RoundTrip(requests[i]) + } +} From 2591fa7c69b06afbfaf15efa787dd08ea53ccdb8 Mon Sep 17 00:00:00 2001 From: Nick Stenning Date: Sat, 6 Apr 2024 15:00:24 +0200 Subject: [PATCH 2/2] Pool buffers and hashes Generating new byte buffers for every single processed request is quite wasteful and generates a lot of work for the GC. Similarly, the hash.Hash struct can be reused, although this provides only a small additional benefit. Before: goos: darwin goarch: arm64 pkg: github.com/replicate/go/http/digest BenchmarkTransport-10 9200 123811 ns/op 1058.65 MB/s 525456 B/op 24 allocs/op PASS ok github.com/replicate/go/http/digest 6.684s After: goos: darwin goarch: arm64 pkg: github.com/replicate/go/http/digest BenchmarkTransport-10 12412 95642 ns/op 1370.45 MB/s 1450 B/op 12 allocs/op PASS ok github.com/replicate/go/http/digest 14.619s --- http/digest/transport.go | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/http/digest/transport.go b/http/digest/transport.go index bdc4128..c0f2f48 100644 --- a/http/digest/transport.go +++ b/http/digest/transport.go @@ -9,8 +9,10 @@ import ( "bytes" "crypto/sha512" "encoding/base64" + "hash" "io" "net/http" + "sync" ) // Transport is an implementation of http.RoundTripper that automatically adds @@ -20,16 +22,32 @@ import ( // order to calculate the digest. type Transport struct { http.RoundTripper + + bufPool sync.Pool + hashPool sync.Pool } func NewTransport(t http.RoundTripper) *Transport { return &Transport{ RoundTripper: t, + + bufPool: sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, + }, + hashPool: sync.Pool{ + New: func() any { + return sha512.New() + }, + }, } } func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - h := sha512.New() + h := t.hashPool.Get().(hash.Hash) + h.Reset() + defer t.hashPool.Put(h) // RoundTrip must not modify the original request. req = req.Clone(req.Context()) @@ -40,12 +58,14 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { body := io.TeeReader(req.Body, h) - var buf bytes.Buffer - if _, err := io.Copy(&buf, body); err != nil { + buf := t.bufPool.Get().(*bytes.Buffer) + buf.Reset() + defer t.bufPool.Put(buf) + if _, err := io.Copy(buf, body); err != nil { return nil, err } - req.Body = io.NopCloser(&buf) + req.Body = io.NopCloser(buf) } digest := base64.StdEncoding.EncodeToString(h.Sum(nil))