Skip to content

Commit d8a64e3

Browse files
authored
fix: add isolation level in transaction repository (#824)
Currently there were chances to have race conditions while writing to transaction repository. I have added a test to verify it doesn't happen. Database is using serializable as the isolation level to avoid overlapping transactions. Signed-off-by: Kush Sharma <[email protected]>
1 parent 1895ed7 commit d8a64e3

File tree

7 files changed

+222
-142
lines changed

7 files changed

+222
-142
lines changed

internal/api/v1beta1/org.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package v1beta1
33
import (
44
"context"
55

6+
"github.com/raystack/frontier/core/serviceuser"
7+
68
"github.com/raystack/frontier/core/authenticate"
79

810
"go.uber.org/zap"
@@ -318,13 +320,15 @@ func (h Handler) ListOrganizationServiceUsers(ctx context.Context, request *fron
318320
}
319321
}
320322

321-
users, err := h.serviceUserService.ListByOrg(ctx, orgResp.ID)
323+
usersList, err := h.serviceUserService.List(ctx, serviceuser.Filter{
324+
OrgID: orgResp.ID,
325+
})
322326
if err != nil {
323327
return nil, err
324328
}
325329

326330
var usersPB []*frontierv1beta1.ServiceUser
327-
for _, rel := range users {
331+
for _, rel := range usersList {
328332
u, err := transformServiceUserToPB(rel)
329333
if err != nil {
330334
return nil, err

internal/api/v1beta1/org_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,9 @@ func TestHandler_ListOrganizationServiceUsers(t *testing.T) {
790790
for _, u := range testUserMap {
791791
testUserList = append(testUserList, u)
792792
}
793-
us.EXPECT().ListByOrg(mock.AnythingOfType("context.backgroundCtx"), testOrgID).Return([]serviceuser.ServiceUser{
793+
us.EXPECT().List(mock.AnythingOfType("context.backgroundCtx"), serviceuser.Filter{
794+
OrgID: testOrgID,
795+
}).Return([]serviceuser.ServiceUser{
794796
{
795797
ID: "9f256f86-31a3-11ec-8d3d-0242ac130003",
796798
Title: "Sample Service User",

internal/store/postgres/billing_transactions_repository.go

Lines changed: 125 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"math/rand"
910
"strings"
1011
"time"
1112

1213
"github.com/raystack/frontier/billing/customer"
14+
1315
"github.com/raystack/frontier/internal/bootstrap/schema"
1416

1517
"github.com/jackc/pgconn"
@@ -81,131 +83,143 @@ func NewBillingTransactionRepository(dbc *db.Client) *BillingTransactionReposito
8183
}
8284
}
8385

86+
var (
87+
maxRetries = 5
88+
// Error codes from https://www.postgresql.org/docs/current/errcodes-appendix.html
89+
serializationFailureCode = "40001"
90+
deadlockDetectedCode = "40P01"
91+
)
92+
93+
func (r BillingTransactionRepository) withRetry(ctx context.Context, fn func() error) error {
94+
var lastErr error
95+
for i := 0; i < maxRetries && ctx.Err() == nil; i++ {
96+
err := fn()
97+
if err == nil {
98+
return nil
99+
}
100+
101+
var pqErr *pgconn.PgError
102+
if errors.As(err, &pqErr) {
103+
// Retry on serialization failures or deadlocks
104+
if pqErr.Code == serializationFailureCode || pqErr.Code == deadlockDetectedCode {
105+
lastErr = err
106+
// Exponential backoff with jitter
107+
backoff := time.Duration(1<<uint(i)) * 50 * time.Millisecond
108+
jitter := time.Duration(rand.Int63n(int64(backoff / 2)))
109+
time.Sleep(backoff + jitter)
110+
continue
111+
}
112+
}
113+
return err // Return immediately for other errors
114+
}
115+
return fmt.Errorf("max retries exceeded: %w", lastErr)
116+
}
117+
84118
func (r BillingTransactionRepository) CreateEntry(ctx context.Context, debitEntry credit.Transaction,
85119
creditEntry credit.Transaction) ([]credit.Transaction, error) {
86-
var customerAcc customer.Customer
120+
txOpts := sql.TxOptions{
121+
Isolation: sql.LevelSerializable,
122+
ReadOnly: false,
123+
}
124+
87125
var err error
126+
var debitModel Transaction
127+
var creditModel Transaction
128+
var customerAcc customer.Customer
129+
88130
if debitEntry.CustomerID != schema.PlatformOrgID.String() {
89-
// only fetch if it's a customer debit entry
90131
customerAcc, err = r.customerRepo.GetByID(ctx, debitEntry.CustomerID)
91132
if err != nil {
92133
return nil, fmt.Errorf("failed to get customer account: %w", err)
93134
}
94135
}
95136

96-
if debitEntry.Metadata == nil {
97-
debitEntry.Metadata = make(map[string]any)
137+
var creditReturnedEntry, debitReturnedEntry credit.Transaction
138+
err = r.withRetry(ctx, func() error {
139+
return r.dbc.WithTxn(ctx, txOpts, func(tx *sqlx.Tx) error {
140+
if debitEntry.CustomerID != schema.PlatformOrgID.String() {
141+
// check for balance only when deducting from customer account
142+
currentBalance, err := r.getBalanceInTx(ctx, tx, debitEntry.CustomerID)
143+
if err != nil {
144+
return fmt.Errorf("failed to get balance: %w", err)
145+
}
146+
147+
if err := isSufficientBalance(customerAcc.CreditMin, currentBalance, debitEntry.Amount); err != nil {
148+
return err
149+
}
150+
}
151+
152+
if err := r.createTransactionEntry(ctx, tx, debitEntry, &debitModel); err != nil {
153+
return fmt.Errorf("failed to create debit entry: %w", err)
154+
}
155+
if err := r.createTransactionEntry(ctx, tx, creditEntry, &creditModel); err != nil {
156+
return fmt.Errorf("failed to create credit entry: %w", err)
157+
}
158+
return nil
159+
})
160+
})
161+
if err != nil {
162+
if errors.Is(err, credit.ErrAlreadyApplied) {
163+
return nil, credit.ErrAlreadyApplied
164+
} else if errors.Is(err, credit.ErrInsufficientCredits) {
165+
return nil, credit.ErrInsufficientCredits
166+
}
167+
return nil, fmt.Errorf("failed to create transaction entry: %w", err)
98168
}
99-
debitMetadata, err := json.Marshal(debitEntry.Metadata)
169+
170+
creditReturnedEntry, err = creditModel.transform()
100171
if err != nil {
101-
return nil, err
102-
}
103-
debitRecord := goqu.Record{
104-
"account_id": debitEntry.CustomerID,
105-
"description": debitEntry.Description,
106-
"type": debitEntry.Type,
107-
"source": debitEntry.Source,
108-
"amount": debitEntry.Amount,
109-
"user_id": debitEntry.UserID,
110-
"metadata": debitMetadata,
111-
"created_at": goqu.L("now()"),
112-
"updated_at": goqu.L("now()"),
172+
return nil, fmt.Errorf("failed to transform credit entry: %w", err)
113173
}
114-
if debitEntry.ID != "" {
115-
debitRecord["id"] = debitEntry.ID
174+
debitReturnedEntry, err = debitModel.transform()
175+
if err != nil {
176+
return nil, fmt.Errorf("failed to transform debit entry: %w", err)
116177
}
178+
return []credit.Transaction{debitReturnedEntry, creditReturnedEntry}, nil
179+
}
117180

118-
if creditEntry.Metadata == nil {
119-
creditEntry.Metadata = make(map[string]any)
181+
func (r BillingTransactionRepository) createTransactionEntry(ctx context.Context, tx *sqlx.Tx, entry credit.Transaction, model *Transaction) error {
182+
if entry.Metadata == nil {
183+
entry.Metadata = make(map[string]any)
120184
}
121-
creditMetadata, err := json.Marshal(creditEntry.Metadata)
185+
metadata, err := json.Marshal(entry.Metadata)
122186
if err != nil {
123-
return nil, err
124-
}
125-
creditRecord := goqu.Record{
126-
"account_id": creditEntry.CustomerID,
127-
"description": creditEntry.Description,
128-
"type": creditEntry.Type,
129-
"source": creditEntry.Source,
130-
"amount": creditEntry.Amount,
131-
"user_id": creditEntry.UserID,
132-
"metadata": creditMetadata,
187+
return err
188+
}
189+
190+
record := goqu.Record{
191+
"account_id": entry.CustomerID,
192+
"description": entry.Description,
193+
"type": entry.Type,
194+
"source": entry.Source,
195+
"amount": entry.Amount,
196+
"user_id": entry.UserID,
197+
"metadata": metadata,
133198
"created_at": goqu.L("now()"),
134199
"updated_at": goqu.L("now()"),
135200
}
136-
if creditEntry.ID != "" {
137-
creditRecord["id"] = creditEntry.ID
201+
if entry.ID != "" {
202+
record["id"] = entry.ID
138203
}
139204

140-
var creditReturnedEntry, debitReturnedEntry credit.Transaction
141-
if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
142-
// check if balance is enough if it's a customer entry
143-
if customerAcc.ID != "" {
144-
currentBalance, err := r.getBalanceInTx(ctx, tx, customerAcc.ID)
145-
if err != nil {
146-
return fmt.Errorf("failed to apply transaction: %w", err)
147-
}
148-
if err := isSufficientBalance(customerAcc.CreditMin, currentBalance, debitEntry.Amount); err != nil {
149-
return err
150-
}
151-
}
152-
153-
var debitModel Transaction
154-
var creditModel Transaction
155-
query, params, err := dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(debitRecord).Returning(&Transaction{}).ToSQL()
156-
if err != nil {
157-
return fmt.Errorf("%w: %s", parseErr, err)
158-
}
159-
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
160-
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&debitModel)
161-
}); err != nil {
162-
var pqErr *pgconn.PgError
163-
if errors.As(err, &pqErr) && (pqErr.Code == "23505") { // handle unique key violations
164-
if pqErr.ConstraintName == "billing_transactions_pkey" { // primary key violation
165-
return credit.ErrAlreadyApplied
166-
}
167-
// add other specific unique key violations here if needed
168-
}
169-
return fmt.Errorf("%w: %s", dbErr, err)
170-
}
171-
172-
query, params, err = dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(creditRecord).Returning(&Transaction{}).ToSQL()
173-
if err != nil {
174-
return fmt.Errorf("%w: %s", parseErr, err)
175-
}
176-
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
177-
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&creditModel)
178-
}); err != nil {
179-
var pqErr *pgconn.PgError
180-
if errors.As(err, &pqErr) && (pqErr.Code == "23505") { // handle unique key violations
181-
if pqErr.ConstraintName == "billing_transactions_pkey" { // primary key violation
182-
return credit.ErrAlreadyApplied
183-
}
184-
// add other specific unique key violations here if needed
185-
}
186-
return fmt.Errorf("%w: %s", dbErr, err)
187-
}
188-
189-
creditReturnedEntry, err = creditModel.transform()
190-
if err != nil {
191-
return fmt.Errorf("failed to transform credit entry: %w", err)
192-
}
193-
debitReturnedEntry, err = debitModel.transform()
194-
if err != nil {
195-
return fmt.Errorf("failed to transform debit entry: %w", err)
196-
}
205+
query, params, err := dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(record).Returning(&Transaction{}).ToSQL()
206+
if err != nil {
207+
return fmt.Errorf("%w: %w", parseErr, err)
208+
}
197209

198-
return nil
210+
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
211+
return tx.QueryRowxContext(ctx, query, params...).StructScan(model)
199212
}); err != nil {
200-
if errors.Is(err, credit.ErrAlreadyApplied) {
201-
return nil, credit.ErrAlreadyApplied
202-
} else if errors.Is(err, credit.ErrInsufficientCredits) {
203-
return nil, credit.ErrInsufficientCredits
213+
var pqErr *pgconn.PgError
214+
if errors.As(err, &pqErr) && (pqErr.Code == "23505") {
215+
if pqErr.ConstraintName == "billing_transactions_pkey" {
216+
return credit.ErrAlreadyApplied
217+
}
204218
}
205-
return nil, fmt.Errorf("failed to create transaction entry: %w", err)
219+
return fmt.Errorf("%w: %w", dbErr, err)
206220
}
207221

208-
return []credit.Transaction{debitReturnedEntry, creditReturnedEntry}, nil
222+
return nil
209223
}
210224

211225
// isSufficientBalance checks if the customer has enough balance to perform the transaction.
@@ -328,6 +342,7 @@ func (r BillingTransactionRepository) getDebitBalance(ctx context.Context, tx *s
328342
"account_id": accountID,
329343
"type": credit.DebitType,
330344
})
345+
331346
query, params, err := stmt.ToSQL()
332347
if err != nil {
333348
return nil, fmt.Errorf("%w: %s", parseErr, err)
@@ -347,6 +362,7 @@ func (r BillingTransactionRepository) getCreditBalance(ctx context.Context, tx *
347362
"account_id": accountID,
348363
"type": credit.CreditType,
349364
})
365+
350366
query, params, err := stmt.ToSQL()
351367
if err != nil {
352368
return nil, fmt.Errorf("%w: %s", parseErr, err)
@@ -388,11 +404,17 @@ func (r BillingTransactionRepository) getBalanceInTx(ctx context.Context, tx *sq
388404
// in transaction table till now.
389405
func (r BillingTransactionRepository) GetBalance(ctx context.Context, accountID string) (int64, error) {
390406
var amount int64
391-
if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
392-
var err error
393-
amount, err = r.getBalanceInTx(ctx, tx, accountID)
394-
return err
395-
}); err != nil {
407+
err := r.withRetry(ctx, func() error {
408+
return r.dbc.WithTxn(ctx, sql.TxOptions{
409+
Isolation: sql.LevelSerializable,
410+
ReadOnly: true,
411+
}, func(tx *sqlx.Tx) error {
412+
var err error
413+
amount, err = r.getBalanceInTx(ctx, tx, accountID)
414+
return err
415+
})
416+
})
417+
if err != nil {
396418
return 0, fmt.Errorf("failed to get balance: %w", err)
397419
}
398420
return amount, nil

0 commit comments

Comments
 (0)