Skip to content

Commit 2c8ea61

Browse files
cstocktonChris Stockton
andauthored
feat: introduce request-scoped background tasks & async mail sending (#2126)
## Summary Introduce a new `apitask` package for running background tasks within the lifecycle of an API request. We then use this feature in the mailer client to send emails in the background. Before the request returns to the caller we first wait for all in-flight tasks to exit. **Summary:** * Define a `Task` interface for light background work. * Add `Run(ctx, task)` to schedule tasks asynchronously when supported, or fall back to synchronous execution otherwise. * Add `Wait(ctx)` to block until all request-scoped tasks finish * Add `Middleware` to wrap HTTP handlers, attaching a request-scoped worker to the context and ensuring all tasks complete before returning a response. * Add a check for `EmailBackgroundSending` in mailer package, when present wrap the mail client in a `backgroundMailClient`. When `Mail` is called it will send a `mailer.Task` to `apitasks.Run` to send the email in the background using the wrapped `MailClient`. * Add config `GOTRUE_MAILER_EMAIL_BACKGROUND_SENDING` (def `false`) - on/off switch. I will follow up with unit tests once the team has had an opportunity to review. --------- Co-authored-by: Chris Stockton <[email protected]>
1 parent 68c40a6 commit 2c8ea61

File tree

9 files changed

+326
-22
lines changed

9 files changed

+326
-22
lines changed

cmd/serve_cmd.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,17 @@ func serve(ctx context.Context) {
5353
opts := []api.Option{
5454
api.NewLimiterOptions(config),
5555
}
56-
a := api.NewAPIWithVersion(config, db, utilities.Version, opts...)
57-
ah := reloader.NewAtomicHandler(a)
58-
logrus.WithField("version", a.Version()).Infof("GoTrue API started on: %s", addr)
5956

6057
baseCtx, baseCancel := context.WithCancel(context.Background())
6158
defer baseCancel()
6259

60+
var wg sync.WaitGroup
61+
defer wg.Wait() // Do not return to caller until this goroutine is done.
62+
63+
a := api.NewAPIWithVersion(config, db, utilities.Version, opts...)
64+
ah := reloader.NewAtomicHandler(a)
65+
logrus.WithField("version", a.Version()).Infof("GoTrue API started on: %s", addr)
66+
6367
httpSrv := &http.Server{
6468
Addr: addr,
6569
Handler: ah,
@@ -70,9 +74,6 @@ func serve(ctx context.Context) {
7074
}
7175
log := logrus.WithField("component", "api")
7276

73-
var wg sync.WaitGroup
74-
defer wg.Wait() // Do not return to caller until this goroutine is done.
75-
7677
if watchDir != "" {
7778
wg.Add(1)
7879
go func() {
@@ -98,7 +99,10 @@ func serve(ctx context.Context) {
9899

99100
<-ctx.Done()
100101

101-
defer baseCancel() // close baseContext
102+
// This must be done after httpSrv exits, otherwise you may potentially
103+
// have 1 or more inflight http requests blocked until the shutdownCtx
104+
// is canceled.
105+
defer baseCancel()
102106

103107
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Minute)
104108
defer shutdownCancel()

internal/api/api.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/sebest/xff"
1010
"github.com/sirupsen/logrus"
1111
"github.com/supabase/auth/internal/api/apierrors"
12+
"github.com/supabase/auth/internal/api/apitask"
1213
"github.com/supabase/auth/internal/api/oauthserver"
1314
"github.com/supabase/auth/internal/conf"
1415
"github.com/supabase/auth/internal/hooks/hookshttp"
@@ -37,10 +38,11 @@ type API struct {
3738
config *conf.GlobalConfiguration
3839
version string
3940

40-
hooksMgr *v0hooks.Manager
41-
hibpClient *hibp.PwnedClient
42-
oauthServer *oauthserver.Server
43-
tokenService *tokens.Service
41+
hooksMgr *v0hooks.Manager
42+
hibpClient *hibp.PwnedClient
43+
oauthServer *oauthserver.Server
44+
mailerClientFunc func() mailer.MailClient
45+
tokenService *tokens.Service
4446

4547
// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
4648
overrideTime func() time.Time
@@ -98,6 +100,11 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
98100
if api.limiterOpts == nil {
99101
api.limiterOpts = NewLimiterOptions(globalConfig)
100102
}
103+
if api.mailerClientFunc == nil {
104+
api.mailerClientFunc = func() mailer.MailClient {
105+
return mailer.NewMailClient(globalConfig)
106+
}
107+
}
101108
if api.hooksMgr == nil {
102109
httpDr := hookshttp.New()
103110
pgfuncDr := hookspgfunc.New(db)
@@ -157,6 +164,10 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
157164
r.UseBypass(api.databaseCleanup(cleanup))
158165
}
159166

167+
if globalConfig.Mailer.EmailBackgroundSending {
168+
r.UseBypass(apitask.Middleware)
169+
}
170+
160171
r.Get("/health", api.HealthCheck)
161172
r.Get("/.well-known/jwks.json", api.Jwks)
162173

@@ -366,7 +377,7 @@ func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error {
366377
// Mailer returns NewMailer with the current tenant config
367378
func (a *API) Mailer() mailer.Mailer {
368379
config := a.config
369-
return mailer.NewMailer(config)
380+
return mailer.NewMailerWithClient(config, a.mailerClientFunc())
370381
}
371382

372383
// ServeHTTP implements the http.Handler interface by passing the request along

internal/api/apitask/apitask.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package apitask
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"sync"
9+
10+
"github.com/sirupsen/logrus"
11+
"github.com/supabase/auth/internal/api/apierrors"
12+
"github.com/supabase/auth/internal/observability"
13+
)
14+
15+
// ErrTask is the base of all errors originating from apitasks.
16+
var ErrTask = errors.New("apitask")
17+
18+
// Middleware wraps next with an http.Handler which adds apitasks handling
19+
// to the request context and waits for all tasks to exit before returning.
20+
func Middleware(next http.Handler) http.Handler {
21+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22+
r = r.WithContext(With(r.Context()))
23+
defer Wait(r.Context())
24+
25+
next.ServeHTTP(w, r)
26+
})
27+
}
28+
29+
// Task is implemented by objects which may be ran in the background.
30+
type Task interface {
31+
32+
// Type return a basic name for a task. It is not expected to be consistent
33+
// with the underlying type, but it should be low cardinality.
34+
Type() string
35+
36+
// Run should run this task.
37+
Run(context.Context) error
38+
}
39+
40+
// Run will run a request-scoped background task in a separate goroutine
41+
// immediately if the current context supports it. Otherwise it makes an
42+
// immediate blocking call to task.Run(ctx).
43+
//
44+
// It is invalid to call Run within a tasks Run method.
45+
func Run(ctx context.Context, task Task) error {
46+
wrk, ok := from(ctx)
47+
if !ok {
48+
return task.Run(ctx)
49+
}
50+
return wrk.run(ctx, task)
51+
}
52+
53+
// Wait will wait for all currently running request-scoped background tasks to
54+
// complete before returning.
55+
func Wait(ctx context.Context) {
56+
wrk, ok := from(ctx)
57+
if !ok {
58+
return
59+
}
60+
wrk.wait()
61+
}
62+
63+
// With sets up the given context for adding request-scoped background tasks.
64+
func With(ctx context.Context) context.Context {
65+
wrk, ok := from(ctx)
66+
if !ok {
67+
wrk = &requestWorker{}
68+
}
69+
return context.WithValue(ctx, ctxKey, wrk)
70+
}
71+
72+
var ctxKey = new(int)
73+
74+
func from(ctx context.Context) (*requestWorker, bool) {
75+
if st, ok := ctx.Value(ctxKey).(*requestWorker); ok && st != nil {
76+
return st, true
77+
}
78+
return nil, false
79+
}
80+
81+
type requestWorker struct {
82+
mu sync.Mutex
83+
wg sync.WaitGroup
84+
done bool
85+
}
86+
87+
func (o *requestWorker) wait() {
88+
o.mu.Lock()
89+
o.done = true
90+
o.mu.Unlock()
91+
92+
o.wg.Wait()
93+
}
94+
95+
func (o *requestWorker) run(ctx context.Context, task Task) error {
96+
o.mu.Lock()
97+
defer o.mu.Unlock()
98+
if o.done {
99+
err := fmt.Errorf(
100+
"%w: unable to run tasks after a call to Wait", ErrTask)
101+
return apierrors.NewInternalServerError(
102+
"failed to run task").WithInternalError(err)
103+
}
104+
105+
o.wg.Add(1)
106+
go func() {
107+
defer o.wg.Done()
108+
109+
if err := task.Run(ctx); err != nil {
110+
typ := task.Type()
111+
err = fmt.Errorf("apitask: error running %q: %w", typ, err)
112+
113+
le := observability.GetLogEntryFromContext(ctx).Entry
114+
le.WithFields(logrus.Fields{
115+
"action": "apitask",
116+
"task_type": typ,
117+
}).WithError(err).Error(err)
118+
}
119+
}()
120+
return nil
121+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package apitask
2+
3+
import (
4+
"context"
5+
"sync/atomic"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
type taskFunc struct {
13+
typ string
14+
fn func(context.Context) error
15+
}
16+
17+
func (o *taskFunc) Type() string { return o.typ }
18+
19+
func (o *taskFunc) Run(ctx context.Context) error { return o.fn(ctx) }
20+
21+
func taskFn(typ string, fn func(context.Context) error) Task {
22+
return &taskFunc{typ: typ, fn: fn}
23+
}
24+
25+
func TestRequestWorker(t *testing.T) {
26+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
27+
defer cancel()
28+
29+
t.Run("RunTasks", func(t *testing.T) {
30+
{
31+
rw, ok := from(ctx)
32+
require.False(t, ok, "request worker must not found in context")
33+
require.Nil(t, rw, "request worker must be nil")
34+
}
35+
36+
withCtx := With(ctx)
37+
{
38+
rw, ok := from(withCtx)
39+
require.True(t, ok, "request worker not found in context")
40+
require.NotNil(t, rw, "request worker was nil")
41+
42+
withCtxDupe := With(withCtx)
43+
sameRw, ok := from(withCtxDupe)
44+
require.True(t, ok, "request worker not found in context")
45+
require.True(t, rw == sameRw, "request worker should be created only once")
46+
}
47+
})
48+
49+
t.Run("RunTasks", func(t *testing.T) {
50+
withCtx := With(ctx)
51+
52+
calls := new(atomic.Int64)
53+
expCalls := 0
54+
for range 16 {
55+
expCalls++
56+
task := taskFn("test.run", func(ctx context.Context) error {
57+
calls.Add(1)
58+
return nil
59+
})
60+
err := Run(withCtx, task)
61+
require.NoError(t, err)
62+
}
63+
64+
{
65+
Wait(withCtx)
66+
67+
gotCalls := int(calls.Load())
68+
require.Equal(t, expCalls, gotCalls)
69+
}
70+
})
71+
}

internal/api/options.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/didip/tollbooth/v5"
77
"github.com/didip/tollbooth/v5/limiter"
88
"github.com/supabase/auth/internal/conf"
9+
"github.com/supabase/auth/internal/mailer"
910
"github.com/supabase/auth/internal/ratelimit"
1011
"github.com/supabase/auth/internal/tokens"
1112
)
@@ -14,6 +15,12 @@ type Option interface {
1415
apply(*API)
1516
}
1617

18+
type MailerOptions struct {
19+
MailerClientFunc func() mailer.MailClient
20+
}
21+
22+
func (mo *MailerOptions) apply(a *API) { a.mailerClientFunc = mo.MailerClientFunc }
23+
1724
type LimiterOptions struct {
1825
Email ratelimit.Limiter
1926
Phone ratelimit.Limiter

internal/conf/configuration.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ type MailerConfiguration struct {
445445
ExternalHosts []string `json:"external_hosts" split_words:"true"`
446446

447447
// EXPERIMENTAL: May be removed in a future release.
448+
EmailBackgroundSending bool `json:"email_background_sending" split_words:"true" default:"false"`
448449
EmailValidationExtended bool `json:"email_validation_extended" split_words:"true" default:"false"`
449450
EmailValidationServiceURL string `json:"email_validation_service_url" split_words:"true"`
450451
EmailValidationServiceHeaders string `json:"email_validation_service_headers" split_words:"true"`

internal/conf/configuration_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ func TestGlobal(t *testing.T) {
5959
assert.Equal(t, time.Hour, gc.RateLimitEmailSent.OverTime)
6060
}
6161

62+
{
63+
gc, err := LoadGlobal("")
64+
require.NoError(t, err)
65+
assert.Equal(t, false, gc.Mailer.EmailBackgroundSending)
66+
67+
os.Setenv("GOTRUE_MAILER_EMAIL_BACKGROUND_SENDING", "true")
68+
gc, err = LoadGlobal("")
69+
require.NoError(t, err)
70+
assert.Equal(t, true, gc.Mailer.EmailBackgroundSending)
71+
}
72+
6273
{
6374
hdrs := gc.Mailer.GetEmailValidationServiceHeaders()
6475
assert.Equal(t, 1, len(hdrs["apikey"]))

0 commit comments

Comments
 (0)