Skip to content

Commit bb512e7

Browse files
committed
Change Wire DB operations into using a runtime type assertion
1 parent 92e95e4 commit bb512e7

File tree

5 files changed

+271
-181
lines changed

5 files changed

+271
-181
lines changed

acme/challenge.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ type wireOidcPayload struct {
393393
}
394394

395395
func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error {
396+
wireDB, ok := db.(WireDB)
397+
if !ok {
398+
return NewErrorISE("db %T is not a WireDB", db)
399+
}
396400
prov, ok := ProvisionerFromContext(ctx)
397401
if !ok {
398402
return NewErrorISE("missing provisioner")
@@ -472,7 +476,7 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO
472476
return WrapErrorISE(err, "error updating challenge")
473477
}
474478

475-
orders, err := db.GetAllOrdersByAccountID(ctx, ch.AccountID)
479+
orders, err := wireDB.GetAllOrdersByAccountID(ctx, ch.AccountID)
476480
if err != nil {
477481
return WrapErrorISE(err, "could not retrieve current order by account id")
478482
}
@@ -481,7 +485,7 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO
481485
}
482486

483487
order := orders[len(orders)-1]
484-
if err := db.CreateOidcToken(ctx, order, transformedIDToken); err != nil {
488+
if err := wireDB.CreateOidcToken(ctx, order, transformedIDToken); err != nil {
485489
return WrapErrorISE(err, "failed storing OIDC id token")
486490
}
487491

@@ -523,6 +527,10 @@ type wireDpopPayload struct {
523527
}
524528

525529
func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *jose.JSONWebKey, payload []byte) error {
530+
wireDB, ok := db.(WireDB)
531+
if !ok {
532+
return NewErrorISE("db %T is not a WireDB", db)
533+
}
526534
prov, ok := ProvisionerFromContext(ctx)
527535
if !ok {
528536
return NewErrorISE("missing provisioner")
@@ -586,7 +594,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j
586594
return WrapErrorISE(err, "error updating challenge")
587595
}
588596

589-
orders, err := db.GetAllOrdersByAccountID(ctx, ch.AccountID)
597+
orders, err := wireDB.GetAllOrdersByAccountID(ctx, ch.AccountID)
590598
if err != nil {
591599
return WrapErrorISE(err, "could not find current order by account id")
592600
}
@@ -595,7 +603,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j
595603
}
596604

597605
order := orders[len(orders)-1]
598-
if err := db.CreateDpopToken(ctx, order, map[string]any(*dpop)); err != nil {
606+
if err := wireDB.CreateDpopToken(ctx, order, map[string]any(*dpop)); err != nil {
599607
return WrapErrorISE(err, "failed storing DPoP token")
600608
}
601609

acme/challenge_test.go

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -962,14 +962,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
962962
payload: payload,
963963
ctx: ctx,
964964
jwk: jwk,
965-
db: &MockDB{
966-
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
967-
assert.Equal(t, "chID", updch.ID)
968-
assert.Equal(t, "token", updch.Token)
969-
assert.Equal(t, StatusValid, updch.Status)
970-
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
971-
assert.Equal(t, string(valueBytes), updch.Value)
972-
return nil
965+
db: &MockWireDB{
966+
MockDB: MockDB{
967+
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
968+
assert.Equal(t, "chID", updch.ID)
969+
assert.Equal(t, "token", updch.Token)
970+
assert.Equal(t, StatusValid, updch.Status)
971+
assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
972+
assert.Equal(t, string(valueBytes), updch.Value)
973+
return nil
974+
},
973975
},
974976
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
975977
assert.Equal(t, "accID", accountID)
@@ -1111,14 +1113,16 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
11111113
payload: payload,
11121114
ctx: ctx,
11131115
jwk: jwk,
1114-
db: &MockDB{
1115-
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
1116-
assert.Equal(t, "chID", updch.ID)
1117-
assert.Equal(t, "token", updch.Token)
1118-
assert.Equal(t, StatusValid, updch.Status)
1119-
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
1120-
assert.Equal(t, string(valueBytes), updch.Value)
1121-
return nil
1116+
db: &MockWireDB{
1117+
MockDB: MockDB{
1118+
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
1119+
assert.Equal(t, "chID", updch.ID)
1120+
assert.Equal(t, "token", updch.Token)
1121+
assert.Equal(t, StatusValid, updch.Status)
1122+
assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type)
1123+
assert.Equal(t, string(valueBytes), updch.Value)
1124+
return nil
1125+
},
11221126
},
11231127
MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) {
11241128
assert.Equal(t, "accID", accountID)

0 commit comments

Comments
 (0)