Skip to content

Commit bd80df8

Browse files
cstocktonChris Stockton
andauthored
feat: add after-user-created hook (#2169)
This PR implements the `after-user-created` hook which runs whenever a new user has been successfully created. Summary: - Add `triggerAfterUserCreated` method to the `*API` object in `internal/api/hooks.go` - Update user creation paths to call `triggerAfterUserCreated` after a new user is persisted: - internal/api/anonymous.go - internal/api/external.go - internal/api/invite.go - internal/api/mail.go - internal/api/signup.go - internal/api/samlacs.go - internal/api/token_oidc.go - internal/api/web3.go - Extend `createAccountFromExternalIdentity` to return an `AccountLinkingDecision` to detect newly created accounts - Add full end-to-end verification of the new hook in `internal/api/e2e_test.go` Co-authored-by: Chris Stockton <[email protected]>
1 parent 3511eb4 commit bd80df8

File tree

12 files changed

+184
-54
lines changed

12 files changed

+184
-54
lines changed

internal/api/anonymous.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
5353
if err != nil {
5454
return apierrors.NewInternalServerError("Database error creating anonymous user").WithInternalError(err)
5555
}
56+
if err := a.triggerAfterUserCreated(r, db, newUser); err != nil {
57+
return err
58+
}
5659

5760
metering.RecordLogin(metering.LoginTypeAnonymous, newUser.ID, nil)
5861
return sendJSON(w, http.StatusOK, token)

internal/api/apitask/apitask.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ type Task interface {
3939
Run(context.Context) error
4040
}
4141

42+
type taskFunc struct {
43+
typ string
44+
fn func(context.Context) error
45+
}
46+
47+
func (o *taskFunc) Type() string { return o.typ }
48+
49+
func (o *taskFunc) Run(ctx context.Context) error { return o.fn(ctx) }
50+
51+
func Func(typ string, fn func(context.Context) error) Task {
52+
return &taskFunc{typ: typ, fn: fn}
53+
}
54+
4255
// Run will run a request-scoped background task in a separate goroutine
4356
// immediately if the current context supports it. Otherwise it makes an
4457
// immediate blocking call to task.Run(ctx).

internal/api/apitask/apitask_test.go

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,6 @@ import (
1212
"github.com/stretchr/testify/require"
1313
)
1414

15-
type taskFunc struct {
16-
typ string
17-
fn func(context.Context) error
18-
}
19-
20-
func (o *taskFunc) Type() string { return o.typ }
21-
22-
func (o *taskFunc) Run(ctx context.Context) error { return o.fn(ctx) }
23-
24-
func taskFn(typ string, fn func(context.Context) error) Task {
25-
return &taskFunc{typ: typ, fn: fn}
26-
}
27-
2815
func TestContext(t *testing.T) {
2916
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
3017
defer cancel()
@@ -59,7 +46,7 @@ func TestRun(t *testing.T) {
5946
expCalls := 0
6047
for range 16 {
6148
expCalls++
62-
task := taskFn("test.run", func(ctx context.Context) error {
49+
task := Func("test.run", func(ctx context.Context) error {
6350
calls.Add(1)
6451
return nil
6552
})
@@ -85,7 +72,7 @@ func TestRun(t *testing.T) {
8572
sentinel := errors.New("sentinel")
8673
for range 16 {
8774
expCalls++
88-
task := taskFn("test.run", func(ctx context.Context) error {
75+
task := Func("test.run", func(ctx context.Context) error {
8976
calls.Add(1)
9077
return sentinel
9178
})
@@ -110,7 +97,7 @@ func TestRun(t *testing.T) {
11097
sentinel := errors.New("sentinel")
11198
for range 16 {
11299
expCalls++
113-
task := taskFn("test.run", func(ctx context.Context) error {
100+
task := Func("test.run", func(ctx context.Context) error {
114101
calls.Add(1)
115102
return sentinel
116103
})
@@ -137,7 +124,7 @@ func TestMiddleware(t *testing.T) {
137124
hrFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
138125
for i := range 10 {
139126
typ := fmt.Sprintf("test-task-%v", i)
140-
task := taskFn(typ, func(ctx context.Context) error {
127+
task := Func(typ, func(ctx context.Context) error {
141128
return nil
142129
})
143130

@@ -164,7 +151,7 @@ func TestMiddleware(t *testing.T) {
164151
hrFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
165152
for i := range 10 {
166153
typ := fmt.Sprintf("test-task-%v", i)
167-
task := taskFn(typ, func(ctx context.Context) error {
154+
task := Func(typ, func(ctx context.Context) error {
168155
return nil
169156
})
170157
err := Run(r.Context(), task)

internal/api/e2e_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,43 @@ func runVerifyBeforeUserCreatedHook(
8787
return latest
8888
}
8989

90+
func runVerifyAfterUserCreatedHook(
91+
t *testing.T,
92+
inst *e2ehooks.Instance,
93+
expUser *models.User,
94+
) *models.User {
95+
var latest *models.User
96+
t.Run("VerifyAfterUserCreatedHook", func(t *testing.T) {
97+
defer inst.HookRecorder.AfterUserCreated.ClearCalls()
98+
99+
calls := inst.HookRecorder.AfterUserCreated.GetCalls()
100+
require.Equal(t, 1, len(calls))
101+
call := calls[0]
102+
103+
hookReq := &v0hooks.AfterUserCreatedInput{}
104+
err := call.Unmarshal(hookReq)
105+
require.NoError(t, err)
106+
require.Equal(t, v0hooks.AfterUserCreated, hookReq.Metadata.Name)
107+
108+
u := hookReq.User
109+
require.Equal(t, expUser.ID, u.ID)
110+
require.Equal(t, expUser.Aud, u.Aud)
111+
require.Equal(t, expUser.Email, u.Email)
112+
require.Equal(t, expUser.AppMetaData, u.AppMetaData)
113+
114+
require.False(t, u.CreatedAt.IsZero())
115+
require.False(t, u.UpdatedAt.IsZero())
116+
117+
err = expUser.Confirm(inst.Conn)
118+
require.NoError(t, err)
119+
120+
latest, err = models.FindUserByID(inst.Conn, expUser.ID)
121+
require.NoError(t, err)
122+
require.NotNil(t, latest)
123+
})
124+
return latest
125+
}
126+
90127
func getAccessToken(
91128
ctx context.Context,
92129
t *testing.T,
@@ -208,6 +245,7 @@ func TestE2EHooks(t *testing.T) {
208245
require.Equal(t, email, res.Email.String())
209246

210247
runVerifyBeforeUserCreatedHook(t, inst, res)
248+
runVerifyAfterUserCreatedHook(t, inst, res)
211249
})
212250

213251
t.Run("SignupPhone", func(t *testing.T) {
@@ -224,6 +262,8 @@ func TestE2EHooks(t *testing.T) {
224262
require.Equal(t, phone, res.Phone.String())
225263

226264
runVerifyBeforeUserCreatedHook(t, inst, res)
265+
runVerifyAfterUserCreatedHook(t, inst, res)
266+
227267
})
228268

229269
t.Run("SignupAnonymously", func(t *testing.T) {
@@ -235,6 +275,8 @@ func TestE2EHooks(t *testing.T) {
235275
require.NoError(t, err)
236276

237277
runVerifyBeforeUserCreatedHook(t, inst, res.User)
278+
runVerifyAfterUserCreatedHook(t, inst, res.User)
279+
238280
})
239281

240282
t.Run("ExternalCallback", func(t *testing.T) {
@@ -246,6 +288,8 @@ func TestE2EHooks(t *testing.T) {
246288
require.NoError(t, err)
247289

248290
runVerifyBeforeUserCreatedHook(t, inst, res.User)
291+
runVerifyAfterUserCreatedHook(t, inst, res.User)
292+
249293
})
250294

251295
t.Run("AdminEndpoints", func(t *testing.T) {
@@ -273,6 +317,8 @@ func TestE2EHooks(t *testing.T) {
273317
require.NoError(t, err)
274318

275319
runVerifyBeforeUserCreatedHook(t, inst, res)
320+
runVerifyAfterUserCreatedHook(t, inst, res)
321+
276322
})
277323

278324
t.Run("AdminGenerateLink", func(t *testing.T) {
@@ -304,6 +350,7 @@ func TestE2EHooks(t *testing.T) {
304350
require.NoError(t, err)
305351

306352
runVerifyBeforeUserCreatedHook(t, inst, &res.User)
353+
runVerifyAfterUserCreatedHook(t, inst, &res.User)
307354
})
308355

309356
t.Run("InviteVerification", func(t *testing.T) {
@@ -332,6 +379,7 @@ func TestE2EHooks(t *testing.T) {
332379
require.NoError(t, err)
333380

334381
runVerifyBeforeUserCreatedHook(t, inst, &res.User)
382+
runVerifyAfterUserCreatedHook(t, inst, &res.User)
335383
})
336384
})
337385
})
@@ -372,6 +420,7 @@ func TestE2EHooks(t *testing.T) {
372420
require.Equal(t, email, mfaUser.Email.String())
373421

374422
mfaUser = runVerifyBeforeUserCreatedHook(t, inst, mfaUser)
423+
runVerifyAfterUserCreatedHook(t, inst, mfaUser)
375424
require.NotNil(t, mfaUser)
376425
mfaUserAccessToken = getAccessToken(
377426
ctx, t, inst, string(mfaUser.Email), defaultPassword)
@@ -562,6 +611,7 @@ func TestE2EHooks(t *testing.T) {
562611
require.Equal(t, email, res.Email.String())
563612

564613
currentUser = runVerifyBeforeUserCreatedHook(t, inst, res)
614+
runVerifyAfterUserCreatedHook(t, inst, res)
565615
require.NotNil(t, currentUser)
566616
inst.HookRecorder.CustomizeAccessToken.ClearCalls()
567617
}

internal/api/external.go

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
218218
}
219219
}
220220

221+
var createdUser bool
221222
var user *models.User
222223
var token *AccessTokenResponse
223224
err = db.Transaction(func(tx *storage.Connection) error {
@@ -231,7 +232,8 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
231232
return terr
232233
}
233234
} else {
234-
if user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType, emailOptional); terr != nil {
235+
createdUser = true
236+
if _, user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType, emailOptional); terr != nil {
235237
return terr
236238
}
237239
}
@@ -253,10 +255,14 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
253255
}
254256
return nil
255257
})
256-
257258
if err != nil {
258259
return err
259260
}
261+
if createdUser {
262+
if err := a.triggerAfterUserCreated(r, db, user); err != nil {
263+
return err
264+
}
265+
}
260266

261267
// Record login for analytics - only when token is issued (not during pkce authorize)
262268
if token != nil {
@@ -290,7 +296,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
290296
return nil
291297
}
292298

293-
func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string, emailOptional bool) (*models.User, error) {
299+
func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string, emailOptional bool) (models.AccountLinkingDecision, *models.User, error) {
294300
ctx := r.Context()
295301
aud := a.requestAud(ctx, r)
296302
config := a.config
@@ -304,28 +310,28 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
304310

305311
decision, terr := models.DetermineAccountLinking(tx, config, userData.Emails, aud, providerType, userData.Metadata.Subject)
306312
if terr != nil {
307-
return nil, terr
313+
return 0, nil, terr
308314
}
309315

310316
switch decision.Decision {
311317
case models.LinkAccount:
312318
user = decision.User
313319

314320
if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil {
315-
return nil, terr
321+
return 0, nil, terr
316322
}
317323

318324
if terr = user.UpdateUserMetaData(tx, identityData); terr != nil {
319-
return nil, terr
325+
return 0, nil, terr
320326
}
321327

322328
if terr = user.UpdateAppMetaDataProviders(tx); terr != nil {
323-
return nil, terr
329+
return 0, nil, terr
324330
}
325331

326332
case models.CreateAccount:
327333
if config.DisableSignup {
328-
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
334+
return 0, nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
329335
}
330336

331337
params := &SignupParams{
@@ -352,15 +358,15 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
352358
// transaction
353359
user, terr = params.ToUserModel(isSSOUser)
354360
if terr != nil {
355-
return nil, terr
361+
return 0, nil, terr
356362
}
357363

358364
if user, terr = a.signupNewUser(tx, user); terr != nil {
359-
return nil, terr
365+
return 0, nil, terr
360366
}
361367

362368
if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil {
363-
return nil, terr
369+
return 0, nil, terr
364370
}
365371
user.Identities = append(user.Identities, *identity)
366372

@@ -370,24 +376,24 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
370376

371377
identity.IdentityData = identityData
372378
if terr = tx.UpdateOnly(identity, "identity_data", "last_sign_in_at"); terr != nil {
373-
return nil, terr
379+
return 0, nil, terr
374380
}
375381
if terr = user.UpdateUserMetaData(tx, identityData); terr != nil {
376-
return nil, terr
382+
return 0, nil, terr
377383
}
378384
if terr = user.UpdateAppMetaDataProviders(tx); terr != nil {
379-
return nil, terr
385+
return 0, nil, terr
380386
}
381387

382388
case models.MultipleAccounts:
383-
return nil, apierrors.NewInternalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)
389+
return 0, nil, apierrors.NewInternalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)
384390

385391
default:
386-
return nil, apierrors.NewInternalServerError("Unknown automatic linking decision: %v", decision.Decision)
392+
return 0, nil, apierrors.NewInternalServerError("Unknown automatic linking decision: %v", decision.Decision)
387393
}
388394

389395
if user.IsBanned() {
390-
return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned")
396+
return 0, nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned")
391397
}
392398

393399
hasEmails := providerType != "web3" && !(emailOptional && decision.CandidateEmail.Email == "")
@@ -398,44 +404,44 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.
398404
// need to be removed when a new oauth identity is being added
399405
// to prevent pre-account takeover attacks from happening.
400406
if terr = user.RemoveUnconfirmedIdentities(tx, identity); terr != nil {
401-
return nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
407+
return 0, nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
402408
}
403409
if decision.CandidateEmail.Verified || config.Mailer.Autoconfirm {
404410
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{
405411
"provider": providerType,
406412
}); terr != nil {
407-
return nil, terr
413+
return 0, nil, terr
408414
}
409415
// fall through to auto-confirm and issue token
410416
if terr = user.Confirm(tx); terr != nil {
411-
return nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
417+
return 0, nil, apierrors.NewInternalServerError("Error updating user").WithInternalError(terr)
412418
}
413419
} else {
414420
emailConfirmationSent := false
415421
if decision.CandidateEmail.Email != "" {
416422
if terr = a.sendConfirmation(r, tx, user, models.ImplicitFlow); terr != nil {
417-
return nil, terr
423+
return 0, nil, terr
418424
}
419425
emailConfirmationSent = true
420426
}
421427

422428
if !config.Mailer.AllowUnverifiedEmailSignIns {
423429
if emailConfirmationSent {
424-
return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)))
430+
return 0, nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)))
425431
}
426432

427-
return nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType)))
433+
return 0, nil, storage.NewCommitWithError(apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType)))
428434
}
429435
}
430436
} else {
431437
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.LoginAction, "", map[string]interface{}{
432438
"provider": providerType,
433439
}); terr != nil {
434-
return nil, terr
440+
return 0, nil, terr
435441
}
436442
}
437443

438-
return user, nil
444+
return decision.Decision, user, nil
439445
}
440446

441447
func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *provider.UserProvidedData, inviteToken, providerType string) (*models.User, error) {

0 commit comments

Comments
 (0)