Skip to content

Commit 060a992

Browse files
cstocktonChris Stockton
andauthored
fix: ensure request context exists in API db operations (#2171)
Some API paths were using `a.db` directly, which bypasses request scoped timeouts/cancellation. This change ensures db work is performed with a connection derived from the current request context. No behavioral are changes are intended here. Co-authored-by: Chris Stockton <[email protected]>
1 parent 0bd1c28 commit 060a992

File tree

9 files changed

+27
-20
lines changed

9 files changed

+27
-20
lines changed

internal/api/admin.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
514514
user := getUser(ctx)
515515
config := a.config
516516
adminUser := getAdminUser(ctx)
517+
db := a.db.WithContext(ctx)
517518

518519
// ShouldSoftDelete defaults to false
519520
params := &adminUserDeleteParams{}
@@ -525,7 +526,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
525526
}
526527
}
527528

528-
err := a.db.Transaction(func(tx *storage.Connection) error {
529+
err := db.Transaction(func(tx *storage.Connection) error {
529530
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, adminUser, models.UserDeletedAction, "", map[string]interface{}{
530531
"user_id": user.ID,
531532
"user_email": user.Email,
@@ -575,8 +576,9 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro
575576
config := a.config
576577
user := getUser(ctx)
577578
factor := getFactor(ctx)
579+
db := a.db.WithContext(ctx)
578580

579-
err := a.db.Transaction(func(tx *storage.Connection) error {
581+
err := db.Transaction(func(tx *storage.Connection) error {
580582
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.DeleteFactorAction, r.RemoteAddr, map[string]interface{}{
581583
"user_id": user.ID,
582584
"factor_id": factor.ID,
@@ -608,12 +610,13 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
608610
user := getUser(ctx)
609611
adminUser := getAdminUser(ctx)
610612
params := &adminUserUpdateFactorParams{}
613+
db := a.db.WithContext(ctx)
611614

612615
if err := retrieveRequestParams(r, params); err != nil {
613616
return err
614617
}
615618

616-
err := a.db.Transaction(func(tx *storage.Connection) error {
619+
err := db.Transaction(func(tx *storage.Connection) error {
617620
if params.FriendlyName != "" {
618621
if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil {
619622
return terr

internal/api/external.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
8383

8484
flowStateID := ""
8585
if isPKCEFlow(flowType) {
86-
flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
86+
flowState, err := generateFlowState(db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
8787
if err != nil {
8888
return "", err
8989
}
@@ -200,7 +200,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
200200
var flowState *models.FlowState
201201
// if there's a non-empty FlowStateID we perform PKCE Flow
202202
if flowStateID := getFlowStateID(ctx); flowStateID != "" {
203-
flowState, err = models.FindFlowStateByID(a.db, flowStateID)
203+
flowState, err = models.FindFlowStateByID(db, flowStateID)
204204
if models.IsNotFoundError(err) {
205205
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err)
206206
} else if err != nil {
@@ -506,7 +506,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p
506506
return user, nil
507507
}
508508

509-
func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) {
509+
func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storage.Connection) (context.Context, error) {
510510
var state string
511511
switch r.Method {
512512
case http.MethodPost:
@@ -564,7 +564,7 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C
564564
if err != nil {
565565
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)")
566566
}
567-
u, err := models.FindUserByID(a.db, linkingTargetUserID)
567+
u, err := models.FindUserByID(db, linkingTargetUserID)
568568
if err != nil {
569569
if models.IsNotFoundError(err) {
570570
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found")

internal/api/external_oauth.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ type OAuthProviderData struct {
2727
// extracting the provider requested
2828
func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
2929
ctx := r.Context()
30+
db := a.db.WithContext(ctx)
31+
3032
oauthToken := r.URL.Query().Get("oauth_token")
3133
if oauthToken != "" {
3234
ctx = withRequestToken(ctx, oauthToken)
@@ -37,7 +39,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con
3739
}
3840

3941
var err error
40-
ctx, err = a.loadExternalState(ctx, r)
42+
ctx, err = a.loadExternalState(ctx, r, db)
4143
if err != nil {
4244
u, uerr := url.ParseRequestURI(a.config.SiteURL)
4345
if uerr != nil {

internal/api/hooks.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,31 @@ import (
1414

1515
func (a *API) triggerBeforeUserCreated(
1616
r *http.Request,
17-
conn *storage.Connection,
17+
db *storage.Connection,
1818
user *models.User,
1919
) error {
2020
if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) {
2121
return nil
2222
}
23-
if err := checkTX(conn); err != nil {
23+
if err := checkTX(db); err != nil {
2424
return err
2525
}
2626

2727
req := v0hooks.NewBeforeUserCreatedInput(r, user)
2828
res := new(v0hooks.BeforeUserCreatedOutput)
29-
return a.hooksMgr.InvokeHook(conn, r, req, res)
29+
return a.hooksMgr.InvokeHook(db, r, req, res)
3030
}
3131

3232
func (a *API) triggerBeforeUserCreatedExternal(
3333
r *http.Request,
34-
conn *storage.Connection,
34+
db *storage.Connection,
3535
userData *provider.UserProvidedData,
3636
providerType string,
3737
) error {
3838
if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) {
3939
return nil
4040
}
41-
if err := checkTX(conn); err != nil {
41+
if err := checkTX(db); err != nil {
4242
return err
4343
}
4444

@@ -55,7 +55,7 @@ func (a *API) triggerBeforeUserCreatedExternal(
5555
err error
5656
decision models.AccountLinkingResult
5757
)
58-
err = a.db.Transaction(func(tx *storage.Connection) error {
58+
err = db.Transaction(func(tx *storage.Connection) error {
5959
decision, err = models.DetermineAccountLinking(
6060
tx, config, userData.Emails, aud,
6161
providerType, userData.Metadata.Subject)
@@ -93,7 +93,7 @@ func (a *API) triggerBeforeUserCreatedExternal(
9393
if err != nil {
9494
return err
9595
}
96-
return a.triggerBeforeUserCreated(r, conn, user)
96+
return a.triggerBeforeUserCreated(r, db, user)
9797
}
9898

9999
func checkTX(conn *storage.Connection) error {

internal/api/identity.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
1717
ctx := r.Context()
18+
db := a.db.WithContext(ctx)
1819
config := a.config
1920

2021
claims := getClaims(ctx)
@@ -49,7 +50,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
4950
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist")
5051
}
5152

52-
err = a.db.Transaction(func(tx *storage.Connection) error {
53+
err = db.Transaction(func(tx *storage.Connection) error {
5354
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.IdentityUnlinkAction, "", map[string]interface{}{
5455
"identity_id": identityToBeDeleted.ID,
5556
"provider": identityToBeDeleted.Provider,

internal/api/magic_link.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {
130130
}
131131

132132
if isPKCEFlow(flowType) {
133-
if _, err = generateFlowState(a.db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil {
133+
if _, err = generateFlowState(db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil {
134134
return err
135135
}
136136
}

internal/api/mfa.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error
418418
},
419419
}
420420
output := v0hooks.SendSMSOutput{}
421-
err := a.hooksMgr.InvokeHook(a.db, r, &input, &output)
421+
err := a.hooksMgr.InvokeHook(db, r, &input, &output)
422422
if err != nil {
423423
return apierrors.NewInternalServerError("error invoking hook")
424424
}

internal/api/token.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
165165
output.Message = v0hooks.DefaultPasswordHookRejectionMessage
166166
}
167167
if output.ShouldLogoutUser {
168-
if err := models.Logout(a.db, user.ID); err != nil {
168+
if err := models.Logout(db, user.ID); err != nil {
169169
return err
170170
}
171171
}

internal/api/token_refresh.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
1919
return err
2020
}
2121

22-
tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, a.db, r, tokens.RefreshTokenGrantParams{
22+
db := a.db.WithContext(ctx)
23+
tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{
2324
RefreshToken: params.RefreshToken,
2425
})
2526
if err != nil {

0 commit comments

Comments
 (0)