Skip to content

Commit aa66bfc

Browse files
committed
Simplify message validation function
The linter complained about high complexity, so I extracted some parts of it into smaller functions.
1 parent db250a4 commit aa66bfc

File tree

3 files changed

+110
-20
lines changed

3 files changed

+110
-20
lines changed

rolling-shutter/chainobserver/db/keyper/extend.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package database
22

33
import (
44
"github.com/ethereum/go-ethereum/common"
5+
"github.com/pkg/errors"
56

67
"github.com/shutter-network/rolling-shutter/rolling-shutter/shdb"
78
)
89

10+
// Contains checks if the given address is present in the KeyperSet.
11+
// It returns true if the address is found, otherwise false.
912
func (s *KeyperSet) Contains(address common.Address) bool {
1013
encodedAddress := shdb.EncodeAddress(address)
1114
for _, m := range s.Keypers {
@@ -15,3 +18,23 @@ func (s *KeyperSet) Contains(address common.Address) bool {
1518
}
1619
return false
1720
}
21+
22+
// GetSubset returns a subset of addresses from the KeyperSet based on the given indices.
23+
// The return value is ordered according to the order of the given indices. If indices contains
24+
// duplicates, the return value will do so as well. If at least one of the given indices is out of
25+
// range, an error is returned.
26+
func (s *KeyperSet) GetSubset(indices []uint64) ([]common.Address, error) {
27+
subset := []common.Address{}
28+
for _, i := range indices {
29+
if i >= uint64(len(s.Keypers)) {
30+
return nil, errors.Errorf("keyper index %d out of range (size %d)", i, len(s.Keypers))
31+
}
32+
addressStr := s.Keypers[i]
33+
address, err := shdb.DecodeAddress(addressStr)
34+
if err != nil {
35+
return nil, err
36+
}
37+
subset = append(subset, address)
38+
}
39+
return subset, nil
40+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package database
2+
3+
import (
4+
"testing"
5+
6+
"github.com/ethereum/go-ethereum/common"
7+
"gotest.tools/v3/assert"
8+
9+
"github.com/shutter-network/rolling-shutter/rolling-shutter/shdb"
10+
)
11+
12+
func makeTestKeyperSet() KeyperSet {
13+
return KeyperSet{
14+
KeyperConfigIndex: 0,
15+
ActivationBlockNumber: 0,
16+
Keypers: []string{
17+
shdb.EncodeAddress(common.HexToAddress("0x0000000000000000000000000000000000000000")),
18+
shdb.EncodeAddress(common.HexToAddress("0x5555555555555555555555555555555555555555")),
19+
shdb.EncodeAddress(common.HexToAddress("0xaAaAaAaaAaAaAaaAaAAAAAAAAaaaAaAaAaaAaaAa")),
20+
},
21+
Threshold: 2,
22+
}
23+
}
24+
25+
func TestKeyperSetContains(t *testing.T) {
26+
keyperSet := makeTestKeyperSet()
27+
addresses, err := shdb.DecodeAddresses(keyperSet.Keypers)
28+
assert.NilError(t, err)
29+
30+
for _, address := range addresses {
31+
assert.Assert(t, keyperSet.Contains(address))
32+
}
33+
assert.Assert(t, !keyperSet.Contains(common.HexToAddress("0xffffffffffffffffffffffffffffffffffffffff")))
34+
}
35+
36+
func TestKeyperSetSubset(t *testing.T) {
37+
keyperSet := makeTestKeyperSet()
38+
testCases := []struct {
39+
indices []uint64
40+
valid bool
41+
}{
42+
{indices: []uint64{0, 1, 2}, valid: true},
43+
{indices: []uint64{}, valid: true},
44+
{indices: []uint64{1, 0}, valid: true},
45+
{indices: []uint64{0, 0}, valid: true},
46+
{indices: []uint64{0, 0, 0, 0}, valid: true},
47+
{indices: []uint64{3}, valid: false},
48+
}
49+
50+
for _, tc := range testCases {
51+
subset, err := keyperSet.GetSubset(tc.indices)
52+
if tc.valid {
53+
assert.Assert(t, len(subset) == len(tc.indices))
54+
for _, i := range tc.indices {
55+
assert.Assert(t, shdb.EncodeAddress(subset[i]) == keyperSet.Keypers[tc.indices[i]])
56+
}
57+
} else {
58+
assert.Assert(t, err != nil)
59+
}
60+
}
61+
}

rolling-shutter/keyperimpl/gnosis/handlers.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"math"
66

7-
"github.com/ethereum/go-ethereum/common"
87
"github.com/jackc/pgx/v4"
98
"github.com/jackc/pgx/v4/pgxpool"
109
pubsub "github.com/libp2p/go-libp2p-pubsub"
@@ -186,6 +185,24 @@ func (h *DecryptionKeysHandler) MessagePrototypes() []p2pmsg.Message {
186185
return []p2pmsg.Message{&p2pmsg.DecryptionKeys{}}
187186
}
188187

188+
func validateSignerIndices(extra *p2pmsg.DecryptionKeys_Gnosis, n int) (pubsub.ValidationResult, error) {
189+
for i, signerIndex := range extra.Gnosis.SignerIndices {
190+
if i >= 1 {
191+
prevSignerIndex := extra.Gnosis.SignerIndices[i-1]
192+
if signerIndex == prevSignerIndex {
193+
return pubsub.ValidationReject, errors.New("duplicate signer index found")
194+
}
195+
if signerIndex < prevSignerIndex {
196+
return pubsub.ValidationReject, errors.New("signer indices not ordered")
197+
}
198+
}
199+
if signerIndex >= uint64(n) {
200+
return pubsub.ValidationReject, errors.New("signer index out of range")
201+
}
202+
}
203+
return pubsub.ValidationAccept, nil
204+
}
205+
189206
func (h *DecryptionKeysHandler) ValidateMessage(ctx context.Context, msg p2pmsg.Message) (pubsub.ValidationResult, error) {
190207
keys := msg.(*p2pmsg.DecryptionKeys)
191208
extra, ok := keys.Extra.(*p2pmsg.DecryptionKeys_Gnosis)
@@ -220,25 +237,14 @@ func (h *DecryptionKeysHandler) ValidateMessage(ctx context.Context, msg p2pmsg.
220237
if int32(len(extra.Gnosis.SignerIndices)) != keyperSet.Threshold {
221238
return pubsub.ValidationReject, errors.Errorf("expected %d signers, got %d", keyperSet.Threshold, len(extra.Gnosis.SignerIndices))
222239
}
223-
signers := []common.Address{}
224-
for i, signerIndex := range extra.Gnosis.SignerIndices {
225-
if i >= 1 {
226-
prevSignerIndex := extra.Gnosis.SignerIndices[i-1]
227-
if signerIndex == prevSignerIndex {
228-
return pubsub.ValidationReject, errors.New("duplicate signer index found")
229-
}
230-
if signerIndex < prevSignerIndex {
231-
return pubsub.ValidationReject, errors.New("signer indices not ordered")
232-
}
233-
}
234-
if signerIndex >= uint64(len(keyperSet.Keypers)) {
235-
return pubsub.ValidationReject, errors.New("signer index out of range")
236-
}
237-
signer, err := shdb.DecodeAddress(keyperSet.Keypers[signerIndex])
238-
if err != nil {
239-
return pubsub.ValidationReject, errors.Wrap(err, "failed to decode signer address")
240-
}
241-
signers = append(signers, signer)
240+
241+
res, err := validateSignerIndices(extra, len(keyperSet.Keypers))
242+
if res != pubsub.ValidationAccept {
243+
return res, err
244+
}
245+
signers, err := keyperSet.GetSubset(extra.Gnosis.SignerIndices)
246+
if err != nil {
247+
return pubsub.ValidationReject, err
242248
}
243249

244250
identityPreimages := []identitypreimage.IdentityPreimage{}

0 commit comments

Comments
 (0)