Skip to content

Commit e17e466

Browse files
committed
Intercept outgoing messages
This allows modifying fields, in particular the extra one to add application specific data.
1 parent d92c56d commit e17e466

File tree

2 files changed

+198
-0
lines changed

2 files changed

+198
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package gnosis
2+
3+
import (
4+
"context"
5+
6+
pubsub "github.com/libp2p/go-libp2p-pubsub"
7+
"github.com/pkg/errors"
8+
9+
"github.com/shutter-network/rolling-shutter/rolling-shutter/p2pmsg"
10+
)
11+
12+
type DecryptionKeySharesHandler struct{}
13+
14+
func (h *DecryptionKeySharesHandler) MessagePrototypes() []p2pmsg.Message {
15+
return []p2pmsg.Message{&p2pmsg.DecryptionKeyShares{}}
16+
}
17+
18+
func (h *DecryptionKeySharesHandler) ValidateMessage(_ context.Context, msg p2pmsg.Message) (pubsub.ValidationResult, error) {
19+
keyShares := msg.(*p2pmsg.DecryptionKeyShares)
20+
_, ok := keyShares.Extra.(*p2pmsg.DecryptionKeyShares_Gnosis)
21+
if !ok {
22+
return pubsub.ValidationReject, errors.Errorf("unexpected extra type %T, expected Gnosis", keyShares.Extra)
23+
}
24+
return pubsub.ValidationAccept, nil
25+
}
26+
27+
func (h *DecryptionKeySharesHandler) HandleMessage(_ context.Context, _ p2pmsg.Message) ([]p2pmsg.Message, error) {
28+
return []p2pmsg.Message{}, nil
29+
}
30+
31+
type DecryptionKeysHandler struct{}
32+
33+
func (h *DecryptionKeysHandler) MessagePrototypes() []p2pmsg.Message {
34+
return []p2pmsg.Message{&p2pmsg.DecryptionKeys{}}
35+
}
36+
37+
func (h *DecryptionKeysHandler) ValidateMessage(_ context.Context, msg p2pmsg.Message) (pubsub.ValidationResult, error) {
38+
keys := msg.(*p2pmsg.DecryptionKeys)
39+
_, ok := keys.Extra.(*p2pmsg.DecryptionKeys_Gnosis)
40+
if !ok {
41+
return pubsub.ValidationReject, errors.Errorf("unexpected extra type %T, expected Gnosis", keys.Extra)
42+
}
43+
return pubsub.ValidationAccept, nil
44+
}
45+
46+
func (h *DecryptionKeysHandler) HandleMessage(_ context.Context, _ p2pmsg.Message) ([]p2pmsg.Message, error) {
47+
return []p2pmsg.Message{}, nil
48+
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package gnosis
2+
3+
import (
4+
"bytes"
5+
"context"
6+
7+
"github.com/jackc/pgx/v4"
8+
"github.com/jackc/pgx/v4/pgxpool"
9+
pubsub "github.com/libp2p/go-libp2p-pubsub"
10+
"github.com/pkg/errors"
11+
"github.com/rs/zerolog/log"
12+
"google.golang.org/protobuf/proto"
13+
14+
"github.com/shutter-network/rolling-shutter/rolling-shutter/keyperimpl/gnosis/database"
15+
"github.com/shutter-network/rolling-shutter/rolling-shutter/medley/identitypreimage"
16+
"github.com/shutter-network/rolling-shutter/rolling-shutter/medley/retry"
17+
"github.com/shutter-network/rolling-shutter/rolling-shutter/medley/service"
18+
"github.com/shutter-network/rolling-shutter/rolling-shutter/p2p"
19+
"github.com/shutter-network/rolling-shutter/rolling-shutter/p2pmsg"
20+
)
21+
22+
type MessagingMiddleware struct {
23+
messaging p2p.Messaging
24+
dbpool *pgxpool.Pool
25+
}
26+
27+
type WrappedMessageHandler struct {
28+
handler p2p.MessageHandler
29+
middleware *MessagingMiddleware
30+
}
31+
32+
func (h *WrappedMessageHandler) MessagePrototypes() []p2pmsg.Message {
33+
return h.handler.MessagePrototypes()
34+
}
35+
36+
func (h *WrappedMessageHandler) ValidateMessage(ctx context.Context, msg p2pmsg.Message) (pubsub.ValidationResult, error) {
37+
return h.handler.ValidateMessage(ctx, msg)
38+
}
39+
40+
func (h *WrappedMessageHandler) HandleMessage(ctx context.Context, msg p2pmsg.Message) ([]p2pmsg.Message, error) {
41+
msgs, err := h.handler.HandleMessage(ctx, msg)
42+
if err != nil {
43+
return []p2pmsg.Message{}, err
44+
}
45+
replacedMsgs := []p2pmsg.Message{}
46+
for _, msg := range msgs {
47+
replacedMsg, err := h.middleware.interceptMessage(ctx, msg)
48+
if err != nil {
49+
return []p2pmsg.Message{}, err
50+
}
51+
replacedMsgs = append(replacedMsgs, replacedMsg)
52+
}
53+
return replacedMsgs, nil
54+
}
55+
56+
func NewMessagingMiddleware(messaging p2p.Messaging, dbpool *pgxpool.Pool) *MessagingMiddleware {
57+
return &MessagingMiddleware{messaging: messaging, dbpool: dbpool}
58+
}
59+
60+
func (i *MessagingMiddleware) Start(_ context.Context, runner service.Runner) error {
61+
return runner.StartService(i.messaging)
62+
}
63+
64+
func (i *MessagingMiddleware) interceptMessage(ctx context.Context, msg p2pmsg.Message) (p2pmsg.Message, error) {
65+
switch msg := msg.(type) {
66+
case *p2pmsg.DecryptionKeyShares:
67+
return i.interceptDecryptionKeyShares(ctx, msg)
68+
case *p2pmsg.DecryptionKeys:
69+
return i.interceptDecryptionKeys(ctx, msg)
70+
default:
71+
return msg, nil
72+
}
73+
}
74+
75+
func (i *MessagingMiddleware) SendMessage(ctx context.Context, msg p2pmsg.Message, opts ...retry.Option) error {
76+
msg, err := i.interceptMessage(ctx, msg)
77+
if err != nil {
78+
return err
79+
}
80+
if msg != nil {
81+
return i.messaging.SendMessage(ctx, msg, opts...)
82+
}
83+
return nil
84+
}
85+
86+
func (i *MessagingMiddleware) AddValidator(ctx p2p.ValidatorFunc, protos ...p2pmsg.Message) {
87+
i.messaging.AddValidator(ctx, protos...)
88+
}
89+
90+
func (i *MessagingMiddleware) AddMessageHandler(mhs ...p2p.MessageHandler) {
91+
for _, mh := range mhs {
92+
wmh := &WrappedMessageHandler{handler: mh, middleware: i}
93+
i.messaging.AddMessageHandler(wmh)
94+
}
95+
}
96+
97+
func (i *MessagingMiddleware) interceptDecryptionKeyShares(
98+
ctx context.Context,
99+
originalMsg *p2pmsg.DecryptionKeyShares,
100+
) (p2pmsg.Message, error) {
101+
queries := database.New(i.dbpool)
102+
currentDecryptionTrigger, err := queries.GetCurrentDecryptionTrigger(ctx, int64(originalMsg.Eon))
103+
if err == pgx.ErrNoRows {
104+
log.Warn().
105+
Uint64("eon", originalMsg.Eon).
106+
Msg("intercepted decryption key shares message with unknown corresponding decryption trigger")
107+
return nil, nil
108+
} else if err != nil {
109+
return nil, errors.Wrapf(err, "failed to get current decryption trigger for eon %d", originalMsg.Eon)
110+
}
111+
identityPreimges := []identitypreimage.IdentityPreimage{}
112+
for _, share := range originalMsg.Shares {
113+
identityPreimges = append(identityPreimges, identitypreimage.IdentityPreimage(share.EpochID))
114+
}
115+
identitiesHash := computeIdentitiesHash(identityPreimges)
116+
if !bytes.Equal(identitiesHash, currentDecryptionTrigger.IdentitiesHash) {
117+
log.Warn().
118+
Uint64("eon", originalMsg.Eon).
119+
Hex("expectedIdentitiesHash", currentDecryptionTrigger.IdentitiesHash).
120+
Hex("actualIdentitiesHash", identitiesHash).
121+
Msg("intercepted decryption key shares message with unexpected identities hash")
122+
return nil, nil
123+
}
124+
125+
msg := proto.Clone(originalMsg).(*p2pmsg.DecryptionKeyShares)
126+
msg.Extra = &p2pmsg.DecryptionKeyShares_Gnosis{
127+
Gnosis: &p2pmsg.GnosisDecryptionKeySharesExtra{
128+
Slot: uint64(currentDecryptionTrigger.Block),
129+
TxPointer: uint64(currentDecryptionTrigger.TxPointer),
130+
Signature: []byte{},
131+
},
132+
}
133+
return msg, nil
134+
}
135+
136+
func (i *MessagingMiddleware) interceptDecryptionKeys(
137+
_ context.Context,
138+
originalMsg *p2pmsg.DecryptionKeys,
139+
) (p2pmsg.Message, error) {
140+
msg := proto.Clone(originalMsg).(*p2pmsg.DecryptionKeys)
141+
msg.Extra = &p2pmsg.DecryptionKeys_Gnosis{
142+
Gnosis: &p2pmsg.GnosisDecryptionKeysExtra{
143+
Slot: 0,
144+
TxPointer: 0,
145+
SignerIndices: []uint64{},
146+
Signatures: [][]byte{},
147+
},
148+
}
149+
return msg, nil
150+
}

0 commit comments

Comments
 (0)