Skip to content

Commit dd55028

Browse files
authored
Add multipart download (#4)
* Add multipart download * actually exit fast on non-retryable errors * Add multipart progress reporting
1 parent b7033e1 commit dd55028

File tree

6 files changed

+158
-46
lines changed

6 files changed

+158
-46
lines changed

.github/workflows/ci.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name: ci
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
15+
- name: Set up Go
16+
uses: actions/setup-go@v5
17+
18+
- name: Run tests
19+
run: go test -v -race ./...

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55
dl-pipe https://example.invalid/my-file.tar | tar x
66
```
77

8+
You may also provide the parts of a multipart tar file and it will be reassembled.
9+
10+
```
11+
dl-pipe https://example.invalid/my-file.tar.part1 https://example.invalid/my-file.tar.part2 https://example.invalid/my-file.tar.part3 | tar x
12+
```
13+
14+
We use this to work around the 5TB size limit of most object storage providers.
15+
816
We also provide an expected hash via the `-hash` option to ensure that the downloaded content is correct. Make sure you set `set -eo pipefail` to ensure your script stops on errors.
917

1018
Install with `go install github.com/zeta-chain/dl-pipe@latest`.

cmd/dl-pipe/main.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ const progressFuncInterval = time.Second * 10
7171

7272
func getProgressFunc() dlpipe.ProgressFunc {
7373
prevLength := uint64(0)
74-
return func(currentLength uint64, totalLength uint64) {
74+
return func(currentLength uint64, totalLength uint64, currentPart int, totalParts int) {
7575
currentLengthStr := humanize.Bytes(currentLength)
7676
totalLengthStr := humanize.Bytes(totalLength)
7777

@@ -81,7 +81,12 @@ func getProgressFunc() dlpipe.ProgressFunc {
8181

8282
percent := float64(currentLength) / float64(totalLength) * 100
8383

84-
fmt.Fprintf(os.Stderr, "Downloaded %s of %s (%.1f%%) at %s/s\n", currentLengthStr, totalLengthStr, percent, rateStr)
84+
partStr := ""
85+
if totalParts > 1 {
86+
partStr = fmt.Sprintf(" (part %d of %d)", currentPart+1, totalParts)
87+
}
88+
89+
fmt.Fprintf(os.Stderr, "Downloaded %s of %s (%.1f%%) at %s/s%s\n", currentLengthStr, totalLengthStr, percent, rateStr, partStr)
8590
}
8691
}
8792

@@ -101,9 +106,9 @@ func main() {
101106
flag.BoolVar(&progress, "progress", false, "Show download progress")
102107
flag.Parse()
103108

104-
url := flag.Arg(0)
105-
if url == "" {
106-
fmt.Fprintf(os.Stderr, ("URL is required"))
109+
urls := flag.Args()
110+
if len(urls) == 0 {
111+
fmt.Fprintf(os.Stderr, ("URL(s) are required"))
107112
os.Exit(1)
108113
}
109114

@@ -119,9 +124,9 @@ func main() {
119124
headerMap[parts[0]] = parts[1]
120125
}
121126

122-
err := dlpipe.DownloadURL(
127+
err := dlpipe.DownloadURLMultipart(
123128
ctx,
124-
url,
129+
urls,
125130
os.Stdout,
126131
dlpipe.WithHeaders(headerMap),
127132
getHashOpt(hash),

download.go

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ package dlpipe
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"fmt"
78
"hash"
89
"io"
910
"net/http"
1011
"strings"
12+
"sync"
1113
"time"
1214

1315
"github.com/miolini/datacounter"
@@ -57,7 +59,7 @@ func WithHeaders(headers map[string]string) DownloadOpt {
5759
}
5860
}
5961

60-
type ProgressFunc func(currentLength uint64, totalLength uint64)
62+
type ProgressFunc func(currentLength, totalLength uint64, currentPart, totalParts int)
6163

6264
func WithProgressFunc(progressFunc ProgressFunc, interval time.Duration) DownloadOpt {
6365
return func(d *downloader) {
@@ -115,7 +117,7 @@ func DefaultRetryParameters() RetryParameters {
115117

116118
type downloader struct {
117119
// these fields are set once
118-
url string
120+
urls []string
119121
writer *datacounter.WriterCounter
120122
httpClient *http.Client
121123
retryParameters RetryParameters
@@ -129,6 +131,9 @@ type downloader struct {
129131

130132
// these fields are updated at runtime
131133
contentLength int64
134+
urlsPosition int
135+
136+
sync.RWMutex
132137
}
133138

134139
func (d *downloader) progressReportLoop(ctx context.Context) {
@@ -137,15 +142,19 @@ func (d *downloader) progressReportLoop(ctx context.Context) {
137142
for {
138143
select {
139144
case <-t.C:
140-
d.progressFunc(d.writer.Count(), uint64(d.contentLength))
145+
d.RLock()
146+
d.progressFunc(d.writer.Count(), uint64(d.contentLength), d.urlsPosition, d.totalPartCount())
147+
d.RUnlock()
141148
case <-ctx.Done():
142149
return
143150
}
144151
}
145152
}
146153

147154
func (d *downloader) runInner(ctx context.Context) (io.ReadCloser, error) {
148-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil)
155+
d.RLock()
156+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.urls[d.urlsPosition], nil)
157+
d.RUnlock()
149158
if err != nil {
150159
return nil, NonRetryableWrapf("create request: %w", err)
151160
}
@@ -176,7 +185,9 @@ func (d *downloader) runInner(ctx context.Context) (io.ReadCloser, error) {
176185
}
177186

178187
if resp.StatusCode != http.StatusPartialContent {
179-
return nil, NonRetryableWrapf("unexpected status code on subsequent read: %d", resp.StatusCode)
188+
// this error should be retried since cloudflare r2 sometimes ignores the range request and
189+
// returns 200
190+
return nil, fmt.Errorf("unexpected status code on subsequent read: %d", resp.StatusCode)
180191
}
181192

182193
// Validate we are receiving the right portion of partial content
@@ -212,15 +223,23 @@ func (d *downloader) run(ctx context.Context) error {
212223
if d.progressFunc != nil {
213224
go d.progressReportLoop(ctx)
214225
}
215-
for {
226+
d.resetWriterPosition()
227+
228+
for d.urlsPosition < d.totalPartCount() {
216229
body, err := d.runInner(ctx)
217-
if err != nil {
218-
return err
219-
}
220-
defer body.Close()
221-
_, err = io.Copy(d.writer, body)
222230
if err == nil {
223-
break
231+
defer body.Close()
232+
_, err = io.Copy(d.writer, body)
233+
if err == nil {
234+
d.Lock()
235+
d.urlsPosition++
236+
d.resetWriterPosition()
237+
d.Unlock()
238+
continue
239+
}
240+
}
241+
if errors.Is(err, ErrNonRetryable{}) {
242+
return err
224243
}
225244
err = d.retryParameters.Wait(ctx, d.writer.Count())
226245
if err != nil {
@@ -236,9 +255,22 @@ func (d *downloader) run(ctx context.Context) error {
236255
return nil
237256
}
238257

258+
func (d *downloader) resetWriterPosition() {
259+
d.writer = datacounter.NewWriterCounter(d.tmpWriter)
260+
d.contentLength = 0
261+
}
262+
263+
func (d *downloader) totalPartCount() int {
264+
return len(d.urls)
265+
}
266+
239267
func DownloadURL(ctx context.Context, url string, writer io.Writer, opts ...DownloadOpt) error {
268+
return DownloadURLMultipart(ctx, []string{url}, writer, opts...)
269+
}
270+
271+
func DownloadURLMultipart(ctx context.Context, urls []string, writer io.Writer, opts ...DownloadOpt) error {
240272
d := &downloader{
241-
url: url,
273+
urls: urls,
242274
tmpWriter: writer,
243275
httpClient: &http.Client{
244276
Transport: &http.Transport{
@@ -254,6 +286,5 @@ func DownloadURL(ctx context.Context, url string, writer io.Writer, opts ...Down
254286
}
255287
opt(d)
256288
}
257-
d.writer = datacounter.NewWriterCounter(d.tmpWriter)
258289
return d.run(ctx)
259290
}

download_test.go

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ import (
44
"context"
55
"crypto/rand"
66
"crypto/sha256"
7+
"errors"
78
"fmt"
89
"io"
910
"log"
1011
"net/http"
1112
"net/http/httptest"
1213
"os"
14+
"path/filepath"
1315
"sync"
1416
"testing"
1517

@@ -26,12 +28,12 @@ func TestUninterruptedDownload(t *testing.T) {
2628
r := require.New(t)
2729
ctx := context.Background()
2830

29-
serverURL, expectedHash, cleanup := serveInterruptedTestFile(t, fileSize, 0)
31+
serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 1)
3032
defer cleanup()
3133

3234
hasher := sha256.New()
3335

34-
err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, expectedHash))
36+
err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash))
3537
r.NoError(err)
3638

3739
givenHash := hasher.Sum(nil)
@@ -42,54 +44,96 @@ func TestUninterruptedMismatch(t *testing.T) {
4244
r := require.New(t)
4345
ctx := context.Background()
4446

45-
serverURL, _, cleanup := serveInterruptedTestFile(t, fileSize, 0)
47+
serverURLs, _, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 1)
4648
defer cleanup()
4749

4850
hasher := sha256.New()
4951

50-
err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, []byte{}))
52+
err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, []byte{}))
5153
r.Error(err)
5254
}
5355

5456
func TestInterruptedDownload(t *testing.T) {
5557
r := require.New(t)
5658
ctx := context.Background()
5759

58-
serverURL, expectedHash, cleanup := serveInterruptedTestFile(t, fileSize, interruptAt)
60+
serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, interruptAt, 1)
5961
defer cleanup()
6062

6163
hasher := sha256.New()
6264

63-
err := DownloadURL(ctx, serverURL, io.Discard, WithExpectedHash(hasher, expectedHash))
65+
err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash))
6466
r.NoError(err)
6567
}
6668

67-
// derived from https://github.com/vansante/go-dl-stream/blob/e29aef86498f37d3506126bc258193f1c913ea55/download_test.go#L166
68-
func serveInterruptedTestFile(t *testing.T, fileSize, interruptAt int64) (serverURL string, sha256Hash []byte, cleanup func()) {
69-
rndFile, err := os.CreateTemp(os.TempDir(), "random_file_*.rnd")
70-
assert.NoError(t, err)
71-
filePath := rndFile.Name()
69+
func TestDownloadMultipart(t *testing.T) {
70+
r := require.New(t)
71+
ctx := context.Background()
72+
73+
serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, 0, 10)
74+
defer cleanup()
7275

7376
hasher := sha256.New()
74-
_, err = io.Copy(io.MultiWriter(hasher, rndFile), io.LimitReader(rand.Reader, fileSize))
75-
assert.NoError(t, err)
76-
assert.NoError(t, rndFile.Close())
7777

78-
mux := http.NewServeMux()
79-
mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
80-
log.Printf("Serving random interrupted file (size: %d, interuptAt: %d), Range: %s", fileSize, interruptAt, request.Header.Get(rangeHeader))
78+
err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash))
79+
r.NoError(err)
80+
}
81+
82+
func TestDownloadMultipartInterrupted(t *testing.T) {
83+
r := require.New(t)
84+
ctx := context.Background()
85+
86+
serverURLs, expectedHash, cleanup := serveInterruptedTestFiles(t, fileSize, interruptAt, 10)
87+
defer cleanup()
88+
89+
hasher := sha256.New()
8190

82-
http.ServeFile(&interruptibleHTTPWriter{
83-
ResponseWriter: writer,
84-
writer: writer,
85-
interruptAt: interruptAt,
86-
}, request, filePath)
91+
err := DownloadURLMultipart(ctx, serverURLs, io.Discard, WithExpectedHash(hasher, expectedHash))
92+
r.NoError(err)
93+
}
94+
95+
func TestErrNonRetryable(t *testing.T) {
96+
err := NonRetryableWrapf("test")
97+
require.True(t, errors.Is(err, ErrNonRetryable{}))
98+
}
8799

88-
})
100+
// derived from https://github.com/vansante/go-dl-stream/blob/e29aef86498f37d3506126bc258193f1c913ea55/download_test.go#L166
101+
func serveInterruptedTestFiles(t *testing.T, fileSize, interruptAt int64, parts int) ([]string, []byte, func()) {
102+
mux := http.NewServeMux()
89103
server := httptest.NewServer(mux)
104+
hasher := sha256.New()
105+
filePaths := []string{}
106+
urls := []string{}
107+
108+
for i := 0; i < parts; i++ {
109+
rndFile, err := os.CreateTemp(os.TempDir(), "random_file_*.rnd")
110+
assert.NoError(t, err)
111+
filePath := rndFile.Name()
112+
filePaths = append(filePaths, filePath)
113+
filePathBase := filepath.Base(filePath)
90114

91-
return server.URL, hasher.Sum(nil), func() {
92-
_ = os.Remove(filePath)
115+
_, err = io.Copy(io.MultiWriter(hasher, rndFile), io.LimitReader(rand.Reader, fileSize))
116+
assert.NoError(t, err)
117+
assert.NoError(t, rndFile.Close())
118+
119+
mux.HandleFunc(filePath, func(writer http.ResponseWriter, request *http.Request) {
120+
log.Printf("Serving random interrupted file %s (size: %d, interuptAt: %d), Range: %s", filePathBase, fileSize, interruptAt, request.Header.Get(rangeHeader))
121+
122+
http.ServeFile(&interruptibleHTTPWriter{
123+
ResponseWriter: writer,
124+
writer: writer,
125+
interruptAt: interruptAt,
126+
}, request, filePath)
127+
128+
})
129+
urls = append(urls, server.URL+filePath)
130+
131+
}
132+
133+
return urls, hasher.Sum(nil), func() {
134+
for _, path := range filePaths {
135+
_ = os.Remove(path)
136+
}
93137
}
94138
}
95139

errors.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ func (e ErrNonRetryable) Unwrap() error {
2828
return e.inner
2929
}
3030

31+
func (e ErrNonRetryable) Is(target error) bool {
32+
_, ok := target.(ErrNonRetryable)
33+
return ok
34+
}
35+
3136
func NonRetryableWrap(err error) error {
3237
return ErrNonRetryable{inner: err}
3338
}

0 commit comments

Comments
 (0)