|
| 1 | +package apitask |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "errors" |
| 6 | + "fmt" |
| 7 | + "net/http" |
| 8 | + "sync" |
| 9 | + |
| 10 | + "github.com/sirupsen/logrus" |
| 11 | + "github.com/supabase/auth/internal/api/apierrors" |
| 12 | + "github.com/supabase/auth/internal/observability" |
| 13 | +) |
| 14 | + |
| 15 | +// ErrTask is the base of all errors originating from apitasks. |
| 16 | +var ErrTask = errors.New("apitask") |
| 17 | + |
| 18 | +// Middleware wraps next with an http.Handler which adds apitasks handling |
| 19 | +// to the request context and waits for all tasks to exit before returning. |
| 20 | +func Middleware(next http.Handler) http.Handler { |
| 21 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 22 | + r = r.WithContext(With(r.Context())) |
| 23 | + defer Wait(r.Context()) |
| 24 | + |
| 25 | + next.ServeHTTP(w, r) |
| 26 | + }) |
| 27 | +} |
| 28 | + |
| 29 | +// Task is implemented by objects which may be ran in the background. |
| 30 | +type Task interface { |
| 31 | + |
| 32 | + // Type return a basic name for a task. It is not expected to be consistent |
| 33 | + // with the underlying type, but it should be low cardinality. |
| 34 | + Type() string |
| 35 | + |
| 36 | + // Run should run this task. |
| 37 | + Run(context.Context) error |
| 38 | +} |
| 39 | + |
| 40 | +// Run will run a request-scoped background task in a separate goroutine |
| 41 | +// immediately if the current context supports it. Otherwise it makes an |
| 42 | +// immediate blocking call to task.Run(ctx). |
| 43 | +// |
| 44 | +// It is invalid to call Run within a tasks Run method. |
| 45 | +func Run(ctx context.Context, task Task) error { |
| 46 | + wrk, ok := from(ctx) |
| 47 | + if !ok { |
| 48 | + return task.Run(ctx) |
| 49 | + } |
| 50 | + return wrk.run(ctx, task) |
| 51 | +} |
| 52 | + |
| 53 | +// Wait will wait for all currently running request-scoped background tasks to |
| 54 | +// complete before returning. |
| 55 | +func Wait(ctx context.Context) { |
| 56 | + wrk, ok := from(ctx) |
| 57 | + if !ok { |
| 58 | + return |
| 59 | + } |
| 60 | + wrk.wait() |
| 61 | +} |
| 62 | + |
| 63 | +// With sets up the given context for adding request-scoped background tasks. |
| 64 | +func With(ctx context.Context) context.Context { |
| 65 | + wrk, ok := from(ctx) |
| 66 | + if !ok { |
| 67 | + wrk = &requestWorker{} |
| 68 | + } |
| 69 | + return context.WithValue(ctx, ctxKey, wrk) |
| 70 | +} |
| 71 | + |
| 72 | +var ctxKey = new(int) |
| 73 | + |
| 74 | +func from(ctx context.Context) (*requestWorker, bool) { |
| 75 | + if st, ok := ctx.Value(ctxKey).(*requestWorker); ok && st != nil { |
| 76 | + return st, true |
| 77 | + } |
| 78 | + return nil, false |
| 79 | +} |
| 80 | + |
| 81 | +type requestWorker struct { |
| 82 | + mu sync.Mutex |
| 83 | + wg sync.WaitGroup |
| 84 | + done bool |
| 85 | +} |
| 86 | + |
| 87 | +func (o *requestWorker) wait() { |
| 88 | + o.mu.Lock() |
| 89 | + o.done = true |
| 90 | + o.mu.Unlock() |
| 91 | + |
| 92 | + o.wg.Wait() |
| 93 | +} |
| 94 | + |
| 95 | +func (o *requestWorker) run(ctx context.Context, task Task) error { |
| 96 | + o.mu.Lock() |
| 97 | + defer o.mu.Unlock() |
| 98 | + if o.done { |
| 99 | + err := fmt.Errorf( |
| 100 | + "%w: unable to run tasks after a call to Wait", ErrTask) |
| 101 | + return apierrors.NewInternalServerError( |
| 102 | + "failed to run task").WithInternalError(err) |
| 103 | + } |
| 104 | + |
| 105 | + o.wg.Add(1) |
| 106 | + go func() { |
| 107 | + defer o.wg.Done() |
| 108 | + |
| 109 | + if err := task.Run(ctx); err != nil { |
| 110 | + typ := task.Type() |
| 111 | + err = fmt.Errorf("apitask: error running %q: %w", typ, err) |
| 112 | + |
| 113 | + le := observability.GetLogEntryFromContext(ctx).Entry |
| 114 | + le.WithFields(logrus.Fields{ |
| 115 | + "action": "apitask", |
| 116 | + "task_type": typ, |
| 117 | + }).WithError(err).Error(err) |
| 118 | + } |
| 119 | + }() |
| 120 | + return nil |
| 121 | +} |
0 commit comments