Skip to content

Commit 45ef1d4

Browse files
Merge pull request #45 from snyk/feat/file-upload-gzip-compression
feat: gzip compress file upload requests
2 parents d4c1727 + cb0ac4b commit 45ef1d4

File tree

6 files changed

+274
-59
lines changed

6 files changed

+274
-59
lines changed

internal/fileupload/lowlevel/client.go

Lines changed: 56 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,20 @@ const FileSizeLimit = 50_000_000 // arbitrary number, chosen to support max size
4444
// FileCountLimit specifies the maximum number of files allowed in a single upload.
4545
const FileCountLimit = 100 // arbitrary number, will need to be re-evaluated
4646

47-
// ContentType is the HTTP header name for content type.
48-
const ContentType = "Content-Type"
49-
5047
// NewClient creates a new file upload client with the given configuration and options.
5148
func NewClient(cfg Config, opts ...Opt) *HTTPClient {
52-
c := HTTPClient{cfg, http.DefaultClient}
49+
httpClient := &http.Client{
50+
Transport: http.DefaultTransport,
51+
}
52+
c := HTTPClient{cfg, httpClient}
5353

5454
for _, opt := range opts {
5555
opt(&c)
5656
}
5757

58+
crt := NewCompressionRoundTripper(c.httpClient.Transport)
59+
c.httpClient.Transport = crt
60+
5861
return &c
5962
}
6063

@@ -80,7 +83,7 @@ func (c *HTTPClient) CreateRevision(ctx context.Context, orgID OrgID) (*UploadRe
8083
url := fmt.Sprintf("%s/hidden/orgs/%s/upload_revisions?version=%s", c.cfg.BaseURL, orgID, APIVersion)
8184
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, buff)
8285
if err != nil {
83-
return nil, fmt.Errorf("failed to create request body: %w", err)
86+
return nil, fmt.Errorf("failed to create revision request: %w", err)
8487
}
8588
req.Header.Set(ContentType, "application/vnd.api+json")
8689

@@ -91,7 +94,7 @@ func (c *HTTPClient) CreateRevision(ctx context.Context, orgID OrgID) (*UploadRe
9194
defer res.Body.Close()
9295

9396
if res.StatusCode != http.StatusCreated {
94-
return nil, c.handleUnexpectedStatusCodes(res.Body, res.StatusCode, res.Status, "create upload revision")
97+
return nil, handleUnexpectedStatusCodes(res.Body, res.StatusCode, res.Status, "create upload revision")
9598
}
9699

97100
var respBody UploadRevisionResponseBody
@@ -112,44 +115,23 @@ func (c *HTTPClient) UploadFiles(ctx context.Context, orgID OrgID, revisionID Re
112115
return ErrEmptyRevisionID
113116
}
114117

115-
if len(files) > FileCountLimit {
116-
return NewFileCountLimitError(len(files), FileCountLimit)
118+
if err := validateFiles(files); err != nil {
119+
return err
117120
}
118121

119-
if len(files) == 0 {
120-
return ErrNoFilesProvided
121-
}
122-
123-
for _, file := range files {
124-
fileInfo, err := file.File.Stat()
125-
if err != nil {
126-
return NewFileAccessError(file.Path, err)
127-
}
128-
129-
if fileInfo.IsDir() {
130-
return NewDirectoryError(file.Path)
131-
}
122+
// Create pipe for multipart data
123+
pipeReader, pipeWriter := io.Pipe()
124+
defer pipeReader.Close()
132125

133-
if fileInfo.Size() > FileSizeLimit {
134-
return NewFileSizeLimitError(file.Path, fileInfo.Size(), FileSizeLimit)
135-
}
136-
}
137-
138-
// Create pipe for streaming multipart data
139-
pReader, pWriter := io.Pipe()
140-
mpartWriter := multipart.NewWriter(pWriter)
126+
mpartWriter := multipart.NewWriter(pipeWriter)
141127

142-
// Start goroutine to write multipart data
143-
go c.streamFilesToPipe(pWriter, mpartWriter, files)
128+
go streamFilesToPipe(pipeWriter, mpartWriter, files)
144129

145-
// Create HTTP request with streaming body
146130
url := fmt.Sprintf("%s/hidden/orgs/%s/upload_revisions/%s/files?version=%s", c.cfg.BaseURL, orgID, revisionID, APIVersion)
147-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pReader)
131+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pipeReader)
148132
if err != nil {
149-
pReader.Close()
150133
return fmt.Errorf("failed to create upload files request: %w", err)
151134
}
152-
153135
req.Header.Set(ContentType, mpartWriter.FormDataContentType())
154136

155137
res, err := c.httpClient.Do(req)
@@ -159,18 +141,21 @@ func (c *HTTPClient) UploadFiles(ctx context.Context, orgID OrgID, revisionID Re
159141
defer res.Body.Close()
160142

161143
if res.StatusCode != http.StatusNoContent {
162-
return c.handleUnexpectedStatusCodes(res.Body, res.StatusCode, res.Status, "upload files")
144+
return handleUnexpectedStatusCodes(res.Body, res.StatusCode, res.Status, "upload files")
163145
}
164146

165147
return nil
166148
}
167149

168-
func (c *HTTPClient) streamFilesToPipe(pWriter *io.PipeWriter, mpartWriter *multipart.Writer, files []UploadFile) {
150+
// streamFilesToPipe writes files to the multipart form.
151+
func streamFilesToPipe(pipeWriter *io.PipeWriter, mpartWriter *multipart.Writer, files []UploadFile) {
169152
var streamError error
170153
defer func() {
171-
pWriter.CloseWithError(streamError)
154+
if closeErr := mpartWriter.Close(); closeErr != nil && streamError == nil {
155+
streamError = closeErr
156+
}
157+
pipeWriter.CloseWithError(streamError)
172158
}()
173-
defer mpartWriter.Close()
174159

175160
for _, file := range files {
176161
// Create form file part
@@ -180,14 +165,41 @@ func (c *HTTPClient) streamFilesToPipe(pWriter *io.PipeWriter, mpartWriter *mult
180165
return
181166
}
182167

183-
_, err = io.Copy(part, file.File)
184-
if err != nil {
168+
if _, err := io.Copy(part, file.File); err != nil {
185169
streamError = fmt.Errorf("failed to copy file content for %s: %w", file.Path, err)
186170
return
187171
}
188172
}
189173
}
190174

175+
// validateFiles validates the files before upload.
176+
func validateFiles(files []UploadFile) error {
177+
if len(files) > FileCountLimit {
178+
return NewFileCountLimitError(len(files), FileCountLimit)
179+
}
180+
181+
if len(files) == 0 {
182+
return ErrNoFilesProvided
183+
}
184+
185+
for _, file := range files {
186+
fileInfo, err := file.File.Stat()
187+
if err != nil {
188+
return NewFileAccessError(file.Path, err)
189+
}
190+
191+
if fileInfo.IsDir() {
192+
return NewDirectoryError(file.Path)
193+
}
194+
195+
if fileInfo.Size() > FileSizeLimit {
196+
return NewFileSizeLimitError(file.Path, fileInfo.Size(), FileSizeLimit)
197+
}
198+
}
199+
200+
return nil
201+
}
202+
191203
// SealRevision seals the specified upload revision, marking it as complete.
192204
func (c *HTTPClient) SealRevision(ctx context.Context, orgID OrgID, revisionID RevisionID) (*SealUploadRevisionResponseBody, error) {
193205
if orgID == uuid.Nil {
@@ -215,7 +227,7 @@ func (c *HTTPClient) SealRevision(ctx context.Context, orgID OrgID, revisionID R
215227
url := fmt.Sprintf("%s/hidden/orgs/%s/upload_revisions/%s?version=%s", c.cfg.BaseURL, orgID, revisionID, APIVersion)
216228
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, buff)
217229
if err != nil {
218-
return nil, fmt.Errorf("failed to create request body: %w", err)
230+
return nil, fmt.Errorf("failed to create seal request: %w", err)
219231
}
220232
req.Header.Set(ContentType, "application/vnd.api+json")
221233

@@ -226,7 +238,7 @@ func (c *HTTPClient) SealRevision(ctx context.Context, orgID OrgID, revisionID R
226238
defer res.Body.Close()
227239

228240
if res.StatusCode != http.StatusOK {
229-
return nil, c.handleUnexpectedStatusCodes(res.Body, res.StatusCode, res.Status, "seal upload revision")
241+
return nil, handleUnexpectedStatusCodes(res.Body, res.StatusCode, res.Status, "seal upload revision")
230242
}
231243

232244
var respBody SealUploadRevisionResponseBody
@@ -237,7 +249,7 @@ func (c *HTTPClient) SealRevision(ctx context.Context, orgID OrgID, revisionID R
237249
return &respBody, nil
238250
}
239251

240-
func (c *HTTPClient) handleUnexpectedStatusCodes(body io.ReadCloser, statusCode int, status, operation string) error {
252+
func handleUnexpectedStatusCodes(body io.ReadCloser, statusCode int, status, operation string) error {
241253
bts, err := io.ReadAll(body)
242254
if err != nil {
243255
return fmt.Errorf("failed to read response body: %w", err)

internal/fileupload/lowlevel/client_test.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package lowlevel_fileupload_test
22

33
import (
4+
"compress/gzip"
45
"context"
56
"errors"
67
"fmt"
@@ -25,6 +26,7 @@ func TestClient_CreateRevision(t *testing.T) {
2526
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2627
assert.Equal(t, http.MethodPost, r.Method)
2728
assert.Equal(t, "application/vnd.api+json", r.Header.Get("Content-Type"))
29+
assert.Equal(t, "gzip", r.Header.Get("Content-Encoding"))
2830
assert.Equal(t, fmt.Sprintf("/hidden/orgs/%s/upload_revisions", orgID), r.URL.Path)
2931
assert.Equal(t, "2024-10-15", r.URL.Query().Get("version"))
3032

@@ -88,6 +90,7 @@ func TestClient_UploadFiles(t *testing.T) {
8890
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
8991
assert.Equal(t, http.MethodPost, r.Method)
9092
assert.Contains(t, r.Header.Get("Content-Type"), "multipart/form-data")
93+
assert.Equal(t, "gzip", r.Header.Get("Content-Encoding"))
9194
assert.Equal(t, fmt.Sprintf("/hidden/orgs/%s/upload_revisions/%s/files", orgID, revID), r.URL.Path)
9295
assert.Equal(t, "2024-10-15", r.URL.Query().Get("version"))
9396

@@ -98,7 +101,9 @@ func TestClient_UploadFiles(t *testing.T) {
98101
boundary := params["boundary"]
99102
require.NotEmpty(t, boundary, "multipart boundary should be present")
100103

101-
reader := multipart.NewReader(r.Body, boundary)
104+
gzipReader, err := gzip.NewReader(r.Body)
105+
require.NoError(t, err)
106+
reader := multipart.NewReader(gzipReader, boundary)
102107

103108
// Read the first (and should be only) part
104109
part, err := reader.NextPart()
@@ -152,6 +157,7 @@ func TestClient_UploadFiles_MultipleFiles(t *testing.T) {
152157
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
153158
assert.Equal(t, http.MethodPost, r.Method)
154159
assert.Contains(t, r.Header.Get("Content-Type"), "multipart/form-data")
160+
assert.Equal(t, "gzip", r.Header.Get("Content-Encoding"))
155161
assert.Equal(t, fmt.Sprintf("/hidden/orgs/%s/upload_revisions/%s/files", orgID, revID), r.URL.Path)
156162
assert.Equal(t, "2024-10-15", r.URL.Query().Get("version"))
157163

@@ -162,7 +168,9 @@ func TestClient_UploadFiles_MultipleFiles(t *testing.T) {
162168
boundary := params["boundary"]
163169
require.NotEmpty(t, boundary, "multipart boundary should be present")
164170

165-
reader := multipart.NewReader(r.Body, boundary)
171+
gzipReader, err := gzip.NewReader(r.Body)
172+
require.NoError(t, err)
173+
reader := multipart.NewReader(gzipReader, boundary)
166174
filesReceived := make(map[string]string)
167175

168176
// Read all parts
@@ -287,7 +295,7 @@ func TestClient_UploadFiles_FileSizeLimit(t *testing.T) {
287295
assert.Error(t, err)
288296
var fileSizeErr *lowlevel_fileupload.FileSizeLimitError
289297
assert.ErrorAs(t, err, &fileSizeErr)
290-
assert.Equal(t, "large_file.txt", fileSizeErr.FileName)
298+
assert.Equal(t, "large_file.txt", fileSizeErr.FilePath)
291299
assert.Equal(t, int64(lowlevel_fileupload.FileSizeLimit+1), fileSizeErr.FileSize)
292300
assert.Equal(t, int64(lowlevel_fileupload.FileSizeLimit), fileSizeErr.Limit)
293301
}
@@ -375,6 +383,7 @@ func TestClient_SealRevision(t *testing.T) {
375383
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
376384
assert.Equal(t, http.MethodPatch, r.Method)
377385
assert.Equal(t, "application/vnd.api+json", r.Header.Get("Content-Type"))
386+
assert.Equal(t, "gzip", r.Header.Get("Content-Encoding"))
378387
assert.Equal(t, fmt.Sprintf("/hidden/orgs/%s/upload_revisions/%s", orgID, revID), r.URL.Path)
379388
assert.Equal(t, "2024-10-15", r.URL.Query().Get("version"))
380389

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package lowlevel_fileupload //nolint:revive // underscore naming is intentional for this internal package.
2+
3+
import (
4+
"compress/gzip"
5+
"io"
6+
"net/http"
7+
)
8+
9+
// CompressionRoundTripper is an http.RoundTripper that automatically compresses
10+
// request bodies using gzip compression. It wraps another RoundTripper and adds
11+
// Content-Encoding: gzip header while removing Content-Length to allow proper
12+
// compression handling.
13+
type CompressionRoundTripper struct {
14+
defaultRoundTripper http.RoundTripper
15+
}
16+
17+
// NewCompressionRoundTripper creates a new CompressionRoundTripper that wraps
18+
// the provided RoundTripper. If drt is nil, http.DefaultTransport is used.
19+
// All HTTP requests with a body will be automatically compressed using gzip.
20+
func NewCompressionRoundTripper(drt http.RoundTripper) *CompressionRoundTripper {
21+
rt := drt
22+
if rt == nil {
23+
rt = http.DefaultTransport
24+
}
25+
return &CompressionRoundTripper{rt}
26+
}
27+
28+
// compressRequestBody wraps the given reader with gzip compression.
29+
func compressRequestBody(body io.Reader) io.ReadCloser {
30+
pipeReader, pipeWriter := io.Pipe()
31+
32+
go func() {
33+
var err error
34+
gzWriter := gzip.NewWriter(pipeWriter)
35+
36+
_, err = io.Copy(gzWriter, body)
37+
38+
if closeErr := gzWriter.Close(); closeErr != nil && err == nil {
39+
err = closeErr
40+
}
41+
pipeWriter.CloseWithError(err)
42+
}()
43+
44+
return pipeReader
45+
}
46+
47+
// RoundTrip implements the http.RoundTripper interface. It compresses the request
48+
// body using gzip if a body is present, sets the Content-Encoding header to "gzip",
49+
// and removes the Content-Length header to allow Go's HTTP client to calculate
50+
// the correct length after compression. Requests without a body are passed through
51+
// unchanged to the wrapped RoundTripper.
52+
func (crt *CompressionRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
53+
if r.Body == nil || r.Body == http.NoBody {
54+
//nolint:wrapcheck // No need to wrap the error here.
55+
return crt.defaultRoundTripper.RoundTrip(r)
56+
}
57+
58+
compressedBody := compressRequestBody(r.Body)
59+
60+
r.Body = compressedBody
61+
r.Header.Set(ContentEncoding, "gzip")
62+
r.Header.Del("Content-Length")
63+
r.ContentLength = -1 // Let Go calculate the length
64+
65+
//nolint:wrapcheck // No need to wrap the error here.
66+
return crt.defaultRoundTripper.RoundTrip(r)
67+
}

0 commit comments

Comments
 (0)