Skip to content

Commit 7df24b0

Browse files
authored
Reset connection after write stream fails (#65)
1 parent 4d5781f commit 7df24b0

File tree

3 files changed

+56
-21
lines changed

3 files changed

+56
-21
lines changed

comm/p2p/libp2p.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func NewCommunication(h host.Host, protocolID protocol.ID) Libp2pCommunication {
3939
h: h,
4040
protocolID: protocolID,
4141
logger: logger,
42-
streamManager: NewStreamManager(),
42+
streamManager: NewStreamManager(h),
4343
}
4444

4545
// start processing incoming messages
@@ -179,21 +179,27 @@ func (c Libp2pCommunication) sendMessage(
179179
}
180180

181181
var stream network.Stream
182-
stream, err = c.streamManager.Stream(sessionID, to)
182+
stream, err = c.streamManager.Stream(sessionID, to, c.protocolID)
183183
if err != nil {
184-
// try to open the stream again if it failed the first time
185-
stream, err = c.h.NewStream(context.TODO(), to, c.protocolID)
186-
if err != nil {
187-
return err
188-
}
189-
c.streamManager.AddStream(sessionID, to, stream)
184+
return err
190185
}
191186

192187
err = WriteStream(msg, bufio.NewWriterSize(stream, defaultBufferSize))
193188
if err != nil {
194189
c.logger.Error().Str("To", to.String()).Err(err).Msg("unable to send message")
190+
c.streamManager.ReleaseStreams(sessionID)
191+
192+
stream, err = c.streamManager.Stream(sessionID, to, c.protocolID)
193+
if err != nil {
194+
return err
195+
}
196+
197+
err = WriteStream(msg, bufio.NewWriterSize(stream, defaultBufferSize))
198+
}
199+
if err != nil {
195200
return err
196201
}
202+
197203
c.logger.Trace().Str(
198204
"To", to.String()).Str(
199205
"MsgType", msgType.String()).Str(

comm/p2p/manager.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
package p2p
55

66
import (
7-
"fmt"
7+
"context"
88
"sync"
99

10+
"github.com/libp2p/go-libp2p/core/host"
1011
"github.com/libp2p/go-libp2p/core/network"
1112
"github.com/libp2p/go-libp2p/core/peer"
13+
"github.com/libp2p/go-libp2p/core/protocol"
1214
"github.com/rs/zerolog/log"
1315
)
1416

@@ -18,13 +20,15 @@ import (
1820
type StreamManager struct {
1921
streamsBySessionID map[string]map[peer.ID]network.Stream
2022
streamLocker *sync.Mutex
23+
host host.Host
2124
}
2225

2326
// NewStreamManager creates new StreamManager
24-
func NewStreamManager() *StreamManager {
27+
func NewStreamManager(host host.Host) *StreamManager {
2528
return &StreamManager{
2629
streamsBySessionID: make(map[string]map[peer.ID]network.Stream),
2730
streamLocker: &sync.Mutex{},
31+
host: host,
2832
}
2933
}
3034

@@ -39,6 +43,10 @@ func (sm *StreamManager) ReleaseStreams(sessionID string) {
3943
}
4044

4145
for peer, stream := range streams {
46+
if stream.Conn() != nil {
47+
_ = stream.Conn().Close()
48+
}
49+
4250
err := stream.Close()
4351
if err != nil {
4452
log.Err(err).Msgf("Cannot close stream to peer %s", peer.String())
@@ -67,14 +75,21 @@ func (sm *StreamManager) AddStream(sessionID string, peerID peer.ID, stream netw
6775
}
6876

6977
// Stream fetches stream by peer and session ID
70-
func (sm *StreamManager) Stream(sessionID string, peerID peer.ID) (network.Stream, error) {
78+
func (sm *StreamManager) Stream(sessionID string, peerID peer.ID, protocolID protocol.ID) (network.Stream, error) {
7179
sm.streamLocker.Lock()
72-
defer sm.streamLocker.Unlock()
7380

7481
stream, ok := sm.streamsBySessionID[sessionID][peerID]
7582
if !ok {
76-
return nil, fmt.Errorf("no stream for peerID %s", peerID)
83+
stream, err := sm.host.NewStream(context.TODO(), peerID, protocolID)
84+
if err != nil {
85+
return nil, err
86+
}
87+
88+
sm.streamLocker.Unlock()
89+
sm.AddStream(sessionID, peerID, stream)
90+
return stream, nil
7791
}
7892

93+
sm.streamLocker.Unlock()
7994
return stream, nil
8095
}

comm/p2p/manager_test.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ import (
77
"testing"
88

99
"github.com/libp2p/go-libp2p/core/peer"
10+
"github.com/libp2p/go-libp2p/core/protocol"
1011
"github.com/sprintertech/sprinter-signing/comm/p2p"
1112

13+
mock_host "github.com/sprintertech/sprinter-signing/comm/p2p/mock/host"
1214
mock_network "github.com/sprintertech/sprinter-signing/comm/p2p/mock/stream"
1315
"github.com/stretchr/testify/suite"
1416
"go.uber.org/mock/gomock"
@@ -17,6 +19,7 @@ import (
1719
type StreamManagerTestSuite struct {
1820
suite.Suite
1921
mockController *gomock.Controller
22+
mockHost *mock_host.MockHost
2023
}
2124

2225
func TestRunStreamManagerTestSuite(t *testing.T) {
@@ -25,14 +28,21 @@ func TestRunStreamManagerTestSuite(t *testing.T) {
2528

2629
func (s *StreamManagerTestSuite) SetupTest() {
2730
s.mockController = gomock.NewController(s.T())
31+
s.mockHost = mock_host.NewMockHost(s.mockController)
2832
}
2933

3034
func (s *StreamManagerTestSuite) Test_ManagingSubscriptions_Success() {
31-
streamManager := p2p.NewStreamManager()
35+
streamManager := p2p.NewStreamManager(s.mockHost)
36+
37+
mockConn := mock_network.NewMockConn(s.mockController)
38+
mockConn.EXPECT().Close().Return(nil).Times(2)
3239

3340
stream1 := mock_network.NewMockStream(s.mockController)
41+
stream1.EXPECT().Conn().Return(mockConn).AnyTimes()
3442
stream2 := mock_network.NewMockStream(s.mockController)
43+
stream2.EXPECT().Conn().Return(mockConn).AnyTimes()
3544
stream3 := mock_network.NewMockStream(s.mockController)
45+
stream3.EXPECT().Conn().Return(mockConn).AnyTimes()
3646

3747
peerID1, _ := peer.Decode("QmcW3oMdSqoEcjbyd51auqC23vhKX6BqfcZcY2HJ3sKAZR")
3848
peerID2, _ := peer.Decode("QmZHPnN3CKiTAp8VaJqszbf8m7v4mPh15M421KpVdYHF54")
@@ -49,36 +59,40 @@ func (s *StreamManagerTestSuite) Test_ManagingSubscriptions_Success() {
4959
}
5060

5161
func (s *StreamManagerTestSuite) Test_FetchStream_NoStream() {
52-
streamManager := p2p.NewStreamManager()
62+
streamManager := p2p.NewStreamManager(s.mockHost)
63+
64+
expectedStream := mock_network.NewMockStream(s.mockController)
65+
s.mockHost.EXPECT().NewStream(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedStream, nil)
5366

54-
_, err := streamManager.Stream("1", peer.ID(""))
67+
stream, err := streamManager.Stream("1", peer.ID(""), protocol.ID(""))
5568

56-
s.NotNil(err)
69+
s.Nil(err)
70+
s.Equal(stream, expectedStream)
5771
}
5872

5973
func (s *StreamManagerTestSuite) Test_FetchStream_ValidStream() {
60-
streamManager := p2p.NewStreamManager()
74+
streamManager := p2p.NewStreamManager(s.mockHost)
6175

6276
stream := mock_network.NewMockStream(s.mockController)
6377
peerID1, _ := peer.Decode("QmcW3oMdSqoEcjbyd51auqC23vhKX6BqfcZcY2HJ3sKAZR")
6478
streamManager.AddStream("1", peerID1, stream)
6579

66-
expectedStream, err := streamManager.Stream("1", peerID1)
80+
expectedStream, err := streamManager.Stream("1", peerID1, protocol.ID(""))
6781

6882
s.Nil(err)
6983
s.Equal(stream, expectedStream)
7084
}
7185

7286
func (s *StreamManagerTestSuite) Test_AddStream_IgnoresExistingPeer() {
73-
streamManager := p2p.NewStreamManager()
87+
streamManager := p2p.NewStreamManager(s.mockHost)
7488

7589
stream1 := mock_network.NewMockStream(s.mockController)
7690
stream2 := mock_network.NewMockStream(s.mockController)
7791
peerID1, _ := peer.Decode("QmcW3oMdSqoEcjbyd51auqC23vhKX6BqfcZcY2HJ3sKAZR")
7892
streamManager.AddStream("1", peerID1, stream1)
7993
streamManager.AddStream("1", peerID1, stream2)
8094

81-
expectedStream, err := streamManager.Stream("1", peerID1)
95+
expectedStream, err := streamManager.Stream("1", peerID1, protocol.ID(""))
8296

8397
s.Nil(err)
8498
s.Equal(stream1, expectedStream)

0 commit comments

Comments
 (0)