Skip to content

Commit b9e9a70

Browse files
committed
feat: separate rate limiting from timeouts
1 parent ba07e0c commit b9e9a70

File tree

9 files changed

+72
-91
lines changed

9 files changed

+72
-91
lines changed

clientutil/clientutil.go

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,12 @@
33
package clientutil
44

55
import (
6-
"context"
76
"log/slog"
87
"net/http"
98
"sync"
109
"time"
1110

1211
"github.com/gregjones/httpcache"
13-
"golang.org/x/time/rate"
1412
)
1513

1614
type Middleware func(http.RoundTripper) http.RoundTripper
@@ -36,34 +34,6 @@ func WithCache() Middleware {
3634
}
3735
}
3836

39-
func WithTimeout(timeout time.Duration) Middleware {
40-
if timeout == 0 {
41-
return Passthrough
42-
}
43-
return func(next http.RoundTripper) http.RoundTripper {
44-
return RoundTripFunc(func(r *http.Request) (*http.Response, error) {
45-
ctx, cancel := context.WithTimeout(r.Context(), timeout)
46-
defer cancel()
47-
return next.RoundTrip(r.WithContext(ctx))
48-
})
49-
}
50-
}
51-
52-
func WithRateLimit(interval time.Duration) Middleware {
53-
if interval == 0 {
54-
return Passthrough
55-
}
56-
return func(next http.RoundTripper) http.RoundTripper {
57-
limiter := rate.NewLimiter(rate.Every(interval), 1)
58-
return RoundTripFunc(func(r *http.Request) (*http.Response, error) {
59-
if err := limiter.Wait(r.Context()); err != nil {
60-
return nil, err
61-
}
62-
return next.RoundTrip(r)
63-
})
64-
}
65-
}
66-
6737
func WithLogging(logger *slog.Logger) Middleware {
6838
return func(next http.RoundTripper) http.RoundTripper {
6939
return RoundTripFunc(func(r *http.Request) (*http.Response, error) {

cmd/internal/wrtagflag/wrtagflag.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"go.senan.xyz/wrtag/notifications"
2323
"go.senan.xyz/wrtag/pathformat"
2424
"go.senan.xyz/wrtag/researchlink"
25+
"golang.org/x/time/rate"
2526

2627
_ "go.senan.xyz/wrtag/addon/lyrics"
2728
_ "go.senan.xyz/wrtag/addon/musicdesc"
@@ -83,10 +84,20 @@ func Config() *wrtag.Config {
8384
flag.Var(&tagConfigParser{&cfg.TagConfig}, "tag-config", "Specify tag keep and drop rules when writing new tag revisions (see [Tagging](#tagging)) (stackable)")
8485

8586
flag.StringVar(&cfg.MusicBrainzClient.BaseURL, "mb-base-url", `https://musicbrainz.org/ws/2/`, "MusicBrainz base URL")
86-
flag.DurationVar(&cfg.MusicBrainzClient.RateLimit, "mb-rate-limit", 1*time.Second, "MusicBrainz rate limit duration")
87+
88+
cfg.MusicBrainzClient.Limiter = rate.NewLimiter(rate.Every(1*time.Second), 1)
89+
flag.Var(&rateLimitParser{cfg.MusicBrainzClient.Limiter}, "mb-rate-limit", "MusicBrainz rate limit duration")
90+
91+
cfg.MusicBrainzClient.HTTPClient = &http.Client{Timeout: 30 * time.Second}
8792

8893
flag.StringVar(&cfg.CoverArtArchiveClient.BaseURL, "caa-base-url", `https://coverartarchive.org/`, "CoverArtArchive base URL")
89-
flag.DurationVar(&cfg.CoverArtArchiveClient.RateLimit, "caa-rate-limit", 0, "CoverArtArchive rate limit duration")
94+
95+
cfg.CoverArtArchiveClient.Limiter = rate.NewLimiter(rate.Inf, 0)
96+
flag.Var(&rateLimitParser{cfg.CoverArtArchiveClient.Limiter}, "caa-rate-limit", "CoverArtArchive rate limit duration")
97+
98+
cfg.CoverArtArchiveClient.HTTPClient = clientutil.Wrap(&http.Client{Timeout: 30 * time.Second},
99+
clientutil.WithCache(),
100+
)
90101

91102
flag.BoolVar(&cfg.UpgradeCover, "cover-upgrade", false, "Fetch new cover art even if it exists locally")
92103

@@ -311,3 +322,26 @@ func (a addonsParser) String() string {
311322
}
312323
return strings.Join(parts, ", ")
313324
}
325+
326+
type rateLimitParser struct {
327+
l *rate.Limiter
328+
}
329+
330+
func (rl rateLimitParser) Set(value string) error {
331+
dur, err := time.ParseDuration(value)
332+
if err != nil {
333+
return err
334+
}
335+
*rl.l = *rate.NewLimiter(rate.Every(dur), 1)
336+
return nil
337+
}
338+
func (rl *rateLimitParser) String() string {
339+
if rl.l == nil {
340+
return ""
341+
}
342+
dur := time.Duration((1.0 / float64(rl.l.Limit())) * float64(time.Second))
343+
if dur == 0 {
344+
return ""
345+
}
346+
return dur.String()
347+
}

lyrics/lyrics.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ import (
66
"context"
77
"errors"
88
"fmt"
9+
"net/http"
910
"strings"
1011
"time"
1112

1213
"golang.org/x/net/html"
14+
"golang.org/x/time/rate"
1315
)
1416

1517
type Source interface {
@@ -19,11 +21,11 @@ type Source interface {
1921
func NewSource(name string) (Source, error) {
2022
switch name {
2123
case "genius":
22-
return &Genius{RateLimit: 500 * time.Millisecond}, nil
24+
return &Genius{&http.Client{}, rate.NewLimiter(rate.Every(500*time.Millisecond), 1)}, nil
2325
case "musixmatch":
24-
return &Musixmatch{RateLimit: 500 * time.Millisecond}, nil
26+
return &Musixmatch{&http.Client{}, rate.NewLimiter(rate.Every(500*time.Millisecond), 1)}, nil
2527
case "lrclib":
26-
return &LRCLib{RateLimit: 100 * time.Millisecond}, nil
28+
return &LRCLib{&http.Client{}, rate.NewLimiter(rate.Every(100*time.Millisecond), 1)}, nil
2729
default:
2830
return nil, errors.New("unknown source")
2931
}

lyrics/lyrics_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/stretchr/testify/assert"
1010
"github.com/stretchr/testify/require"
1111
"go.senan.xyz/wrtag/lyrics"
12+
"golang.org/x/time/rate"
1213
)
1314

1415
//go:embed testdata
@@ -19,6 +20,7 @@ func TestMusixmatch(t *testing.T) {
1920

2021
var src lyrics.Musixmatch
2122
src.HTTPClient = fsClient(responses, "testdata/musixmatch")
23+
src.Limiter = rate.NewLimiter(rate.Inf, 0)
2224

2325
resp, err := src.Search(t.Context(), "The Fall", "Wings", 0)
2426
require.NoError(t, err)
@@ -41,6 +43,7 @@ func TestGenius(t *testing.T) {
4143

4244
var src lyrics.Genius
4345
src.HTTPClient = fsClient(responses, "testdata/genius")
46+
src.Limiter = rate.NewLimiter(rate.Inf, 0)
4447

4548
resp, err := src.Search(t.Context(), "the fall", "totally wired", 0)
4649
require.NoError(t, err)
@@ -64,6 +67,7 @@ func TestGeniusLineBreak(t *testing.T) {
6467

6568
var src lyrics.Genius
6669
src.HTTPClient = fsClient(responses, "testdata/genius")
70+
src.Limiter = rate.NewLimiter(rate.Inf, 0)
6771

6872
resp, err := src.Search(t.Context(), "pink floyd", "breathe in the air", 0)
6973
require.NoError(t, err)

lyrics/source_genius.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@ import (
66
"net/http"
77
"net/url"
88
"strings"
9-
"sync"
109
"time"
1110

1211
"github.com/andybalholm/cascadia"
13-
"go.senan.xyz/wrtag/clientutil"
1412
"golang.org/x/net/html"
13+
"golang.org/x/time/rate"
1514
)
1615

1716
var geniusBaseURL = `https://genius.com`
@@ -32,18 +31,14 @@ var geniusEsc = strings.NewReplacer(
3231
)
3332

3433
type Genius struct {
35-
RateLimit time.Duration
36-
37-
initOnce sync.Once
3834
HTTPClient *http.Client
35+
Limiter *rate.Limiter
3936
}
4037

4138
func (g *Genius) Search(ctx context.Context, artist, song string, duration time.Duration) (string, error) {
42-
g.initOnce.Do(func() {
43-
g.HTTPClient = clientutil.Wrap(g.HTTPClient, clientutil.Chain(
44-
clientutil.WithRateLimit(g.RateLimit),
45-
))
46-
})
39+
if err := g.Limiter.Wait(ctx); err != nil {
40+
return "", err
41+
}
4742

4843
// use genius case rules to miminise redirects
4944
page := fmt.Sprintf("%s-%s-lyrics", artist, song)

lyrics/source_lrclib.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,22 @@ import (
66
"fmt"
77
"net/http"
88
"net/url"
9-
"sync"
109
"time"
1110

12-
"go.senan.xyz/wrtag/clientutil"
11+
"golang.org/x/time/rate"
1312
)
1413

1514
var lrclibBaseURL = `https://lrclib.net/api/get`
1615

1716
type LRCLib struct {
18-
RateLimit time.Duration
19-
20-
initOnce sync.Once
2117
HTTPClient *http.Client
18+
Limiter *rate.Limiter
2219
}
2320

2421
func (l *LRCLib) Search(ctx context.Context, artist, song string, duration time.Duration) (string, error) {
25-
l.initOnce.Do(func() {
26-
l.HTTPClient = clientutil.Wrap(l.HTTPClient, clientutil.Chain(
27-
clientutil.WithRateLimit(l.RateLimit),
28-
))
29-
})
22+
if err := l.Limiter.Wait(ctx); err != nil {
23+
return "", err
24+
}
3025

3126
u, _ := url.Parse(lrclibBaseURL)
3227
q := u.Query()

lyrics/source_musixmatch.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@ import (
66
"net/http"
77
"net/url"
88
"strings"
9-
"sync"
109
"time"
1110

1211
"github.com/andybalholm/cascadia"
13-
"go.senan.xyz/wrtag/clientutil"
1412
"golang.org/x/net/html"
13+
"golang.org/x/time/rate"
1514
)
1615

1716
var musixmatchBaseURL = `https://www.musixmatch.com/lyrics`
@@ -33,18 +32,14 @@ var musixmatchEsc = strings.NewReplacer(
3332
)
3433

3534
type Musixmatch struct {
36-
RateLimit time.Duration
37-
38-
initOnce sync.Once
3935
HTTPClient *http.Client
36+
Limiter *rate.Limiter
4037
}
4138

4239
func (mm *Musixmatch) Search(ctx context.Context, artist, song string, duration time.Duration) (string, error) {
43-
mm.initOnce.Do(func() {
44-
mm.HTTPClient = clientutil.Wrap(mm.HTTPClient, clientutil.Chain(
45-
clientutil.WithRateLimit(mm.RateLimit),
46-
))
47-
})
40+
if err := mm.Limiter.Wait(ctx); err != nil {
41+
return "", err
42+
}
4843

4944
url, _ := url.Parse(musixmatchBaseURL)
5045
url = url.JoinPath(musixmatchEsc.Replace(artist))

musicbrainz/coverartarchive.go

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,14 @@ import (
66
"errors"
77
"fmt"
88
"net/http"
9-
"sync"
10-
"time"
119

12-
"go.senan.xyz/wrtag/clientutil"
10+
"golang.org/x/time/rate"
1311
)
1412

1513
type CAAClient struct {
16-
BaseURL string
17-
RateLimit time.Duration
18-
19-
initOnce sync.Once
14+
BaseURL string
2015
HTTPClient *http.Client
16+
Limiter *rate.Limiter
2117
}
2218

2319
func (c *CAAClient) GetCoverURL(ctx context.Context, release *Release) (string, error) {
@@ -74,13 +70,9 @@ type caaResponse struct {
7470
}
7571

7672
func (c *CAAClient) request(ctx context.Context, r *http.Request, dest any) error {
77-
c.initOnce.Do(func() {
78-
c.HTTPClient = clientutil.Wrap(c.HTTPClient, clientutil.Chain(
79-
clientutil.WithCache(),
80-
clientutil.WithRateLimit(c.RateLimit),
81-
clientutil.WithTimeout(30*time.Second),
82-
))
83-
})
73+
if err := c.Limiter.Wait(ctx); err != nil {
74+
return err
75+
}
8476

8577
r = r.WithContext(ctx)
8678
resp, err := c.HTTPClient.Do(r)

musicbrainz/musicbrainz.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,19 @@ import (
1515
"slices"
1616
"strconv"
1717
"strings"
18-
"sync"
1918
"time"
2019
"unicode"
2120

2221
"github.com/araddon/dateparse"
23-
"go.senan.xyz/wrtag/clientutil"
22+
"golang.org/x/time/rate"
2423
)
2524

2625
var ErrNoResults = errors.New("no results")
2726

2827
type MBClient struct {
29-
BaseURL string
30-
RateLimit time.Duration
31-
32-
initOnce sync.Once
28+
BaseURL string
3329
HTTPClient *http.Client
30+
Limiter *rate.Limiter
3431
}
3532

3633
func (c *MBClient) GetRelease(ctx context.Context, mbid string) (*Release, error) {
@@ -146,12 +143,9 @@ func (c *MBClient) SearchRelease(ctx context.Context, q ReleaseQuery) (*Release,
146143
}
147144

148145
func (c *MBClient) request(ctx context.Context, r *http.Request, dest any) error {
149-
c.initOnce.Do(func() {
150-
c.HTTPClient = clientutil.Wrap(c.HTTPClient, clientutil.Chain(
151-
clientutil.WithRateLimit(c.RateLimit),
152-
clientutil.WithTimeout(30*time.Second),
153-
))
154-
})
146+
if err := c.Limiter.Wait(ctx); err != nil {
147+
return err
148+
}
155149

156150
r = r.WithContext(ctx)
157151
resp, err := c.HTTPClient.Do(r)

0 commit comments

Comments
 (0)