@@ -3,8 +3,8 @@ package epochkghandler
3
3
import (
4
4
"bytes"
5
5
"context"
6
- "math"
7
6
7
+ lru "github.com/hashicorp/golang-lru/v2"
8
8
"github.com/jackc/pgx/v4"
9
9
"github.com/jackc/pgx/v4/pgxpool"
10
10
pubsub "github.com/libp2p/go-libp2p-pubsub"
@@ -13,6 +13,7 @@ import (
13
13
"github.com/shutter-network/shutter/shlib/shcrypto"
14
14
15
15
"github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/database"
16
+ "github.com/shutter-network/rolling-shutter/rolling-shutter/medley"
16
17
"github.com/shutter-network/rolling-shutter/rolling-shutter/p2p"
17
18
"github.com/shutter-network/rolling-shutter/rolling-shutter/p2pmsg"
18
19
"github.com/shutter-network/rolling-shutter/rolling-shutter/shdb"
@@ -21,12 +22,16 @@ import (
21
22
const MaxNumKeysPerMessage = 128
22
23
23
24
func NewDecryptionKeyHandler (config Config , dbpool * pgxpool.Pool ) p2p.MessageHandler {
24
- return & DecryptionKeyHandler {config : config , dbpool : dbpool }
25
+ // Not catching the error as it only can happen if non-positive size was applied
26
+ cache , _ := lru.New [shcrypto.EpochSecretKey , []byte ](1024 )
27
+ return & DecryptionKeyHandler {config : config , dbpool : dbpool , cache : cache }
25
28
}
26
29
27
30
type DecryptionKeyHandler struct {
28
31
config Config
29
32
dbpool * pgxpool.Pool
33
+ // keep 1024 verified keys in Cache to skip additional verifications
34
+ cache * lru.Cache [shcrypto.EpochSecretKey , []byte ]
30
35
}
31
36
32
37
func (* DecryptionKeyHandler ) MessagePrototypes () []p2pmsg.Message {
@@ -39,23 +44,23 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2
39
44
return pubsub .ValidationReject ,
40
45
errors .Errorf ("instance ID mismatch (want=%d, have=%d)" , handler .config .GetInstanceID (), key .GetInstanceID ())
41
46
}
42
- if key .Eon > math .MaxInt64 {
43
- return pubsub .ValidationReject , errors .Errorf ("eon %d overflows int64" , key .Eon )
47
+ eon , err := medley .Uint64ToInt64Safe (key .Eon )
48
+ if err != nil {
49
+ return pubsub .ValidationReject , errors .Wrapf (err , "overflow error while converting eon to int64 %d" , eon )
44
50
}
45
-
46
- dkgResultDB , err := database .New (handler .dbpool ).GetDKGResultForKeyperConfigIndex (ctx , int64 (key .Eon ))
47
- if err == pgx .ErrNoRows {
48
- return pubsub .ValidationReject , errors .Errorf ("no DKG result found for eon %d" , key .Eon )
51
+ dkgResultDB , err := database .New (handler .dbpool ).GetDKGResultForKeyperConfigIndex (ctx , eon )
52
+ if errors .Is (err , pgx .ErrNoRows ) {
53
+ return pubsub .ValidationReject , errors .Errorf ("no DKG result found for eon %d" , eon )
49
54
}
50
55
if err != nil {
51
- return pubsub .ValidationReject , errors .Wrapf (err , "failed to get dkg result for eon %d from db" , key . Eon )
56
+ return pubsub .ValidationReject , errors .Wrapf (err , "failed to get dkg result for eon %d from db" , eon )
52
57
}
53
58
if ! dkgResultDB .Success {
54
- return pubsub .ValidationReject , errors .Errorf ("no successful DKG result found for eon %d" , key . Eon )
59
+ return pubsub .ValidationReject , errors .Errorf ("no successful DKG result found for eon %d" , eon )
55
60
}
56
61
pureDKGResult , err := shdb .DecodePureDKGResult (dkgResultDB .PureResult )
57
62
if err != nil {
58
- return pubsub .ValidationReject , errors .Wrapf (err , "error while decoding pure DKG result for eon %d" , key . Eon )
63
+ return pubsub .ValidationReject , errors .Wrapf (err , "error while decoding pure DKG result for eon %d" , eon )
59
64
}
60
65
61
66
if len (key .Keys ) == 0 {
@@ -64,19 +69,26 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2
64
69
if len (key .Keys ) > MaxNumKeysPerMessage {
65
70
return pubsub .ValidationReject , errors .Errorf ("too many keys in message (%d > %d)" , len (key .Keys ), MaxNumKeysPerMessage )
66
71
}
72
+
67
73
for i , k := range key .Keys {
68
74
epochSecretKey , err := k .GetEpochSecretKey ()
69
75
if err != nil {
70
76
return pubsub .ValidationReject , err
71
77
}
78
+ identity , exists := handler .cache .Get (* epochSecretKey )
79
+ if exists {
80
+ if bytes .Equal (k .Identity , identity ) {
81
+ continue
82
+ }
83
+ return pubsub .ValidationReject , errors .Errorf ("epoch secret key for identity %x is not valid" , k .Identity )
84
+ }
72
85
ok , err := shcrypto .VerifyEpochSecretKey (epochSecretKey , pureDKGResult .PublicKey , k .Identity )
73
86
if err != nil {
74
87
return pubsub .ValidationReject , errors .Wrapf (err , "error while checking epoch secret key for identity %x" , k .Identity )
75
88
}
76
89
if ! ok {
77
90
return pubsub .ValidationReject , errors .Errorf ("epoch secret key for identity %x is not valid" , k .Identity )
78
91
}
79
-
80
92
if i > 0 && bytes .Compare (k .Identity , key .Keys [i - 1 ].Identity ) < 0 {
81
93
return pubsub .ValidationReject , errors .Errorf ("keys not ordered" )
82
94
}
@@ -87,7 +99,15 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2
87
99
func (handler * DecryptionKeyHandler ) HandleMessage (ctx context.Context , msg p2pmsg.Message ) ([]p2pmsg.Message , error ) {
88
100
metricsEpochKGDecryptionKeysReceived .Inc ()
89
101
key := msg .(* p2pmsg.DecryptionKeys )
90
- // Insert the key into the db. We assume that it's valid as it already passed the libp2p
91
- // validator.
102
+ // We assume that it's valid as it already passed the libp2p validator.
103
+ // Insert the key into the cache.
104
+ for _ , k := range key .Keys {
105
+ epochSecretKey , err := k .GetEpochSecretKey ()
106
+ if err != nil {
107
+ return nil , err
108
+ }
109
+ handler .cache .Add (* epochSecretKey , k .Identity )
110
+ }
111
+ // Insert the key into the db.
92
112
return nil , database .New (handler .dbpool ).InsertDecryptionKeysMsg (ctx , key )
93
113
}
0 commit comments