Skip to content

Commit 70bb73d

Browse files
authored
fix: uses unimplemented server if impl is nil (#1690)
1 parent 5c5be51 commit 70bb73d

File tree

2 files changed

+62
-6
lines changed

2 files changed

+62
-6
lines changed

pkg/loop/internal/core/services/gateway/gateway_connector.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (s GatewayConnectorServer) AddHandler(ctx context.Context, req *pb.AddHandl
127127
return nil, fmt.Errorf("failed to dial handler: %d: %w", req.HandlerId, err)
128128
}
129129
client := NewGatewayConnectorHandlerClient(conn)
130-
err = s.impl.AddHandler(ctx, req.Methods, client)
130+
err = s.getImpl().AddHandler(ctx, req.Methods, client)
131131
if err != nil {
132132
return nil, fmt.Errorf("failed to add handler: %d: %w", req.HandlerId, err)
133133
}
@@ -140,38 +140,48 @@ func (s GatewayConnectorServer) SendToGateway(ctx context.Context, req *pb.SendM
140140
if err != nil {
141141
return nil, fmt.Errorf("failed to decode response: %w", err)
142142
}
143-
if err := s.impl.SendToGateway(ctx, req.GatewayId, &resp); err != nil {
143+
if err := s.getImpl().SendToGateway(ctx, req.GatewayId, &resp); err != nil {
144144
return nil, fmt.Errorf("failed to send message to gateway: %s: %w", req.GatewayId, err)
145145
}
146146
return &emptypb.Empty{}, nil
147147
}
148+
148149
func (s GatewayConnectorServer) SignMessage(ctx context.Context, req *pb.SignMessageRequest) (*pb.SignMessageReply, error) {
149-
signature, err := s.impl.SignMessage(ctx, req.Message)
150+
signature, err := s.getImpl().SignMessage(ctx, req.Message)
150151
if err != nil {
151152
return nil, fmt.Errorf("failed to sign message: %w", err)
152153
}
153154
return &pb.SignMessageReply{
154155
Signature: signature,
155156
}, nil
156157
}
158+
157159
func (s GatewayConnectorServer) GatewayIDs(ctx context.Context, _ *emptypb.Empty) (*pb.GatewayIDsReply, error) {
158-
gatewayIDs, err := s.impl.GatewayIDs(ctx)
160+
gatewayIDs, err := s.getImpl().GatewayIDs(ctx)
159161
if err != nil {
160162
return nil, fmt.Errorf("failed to get gateway IDs: %w", err)
161163
}
162164
return &pb.GatewayIDsReply{GatewayIds: gatewayIDs}, nil
163165
}
164166

165167
func (s GatewayConnectorServer) DonID(ctx context.Context, _ *emptypb.Empty) (*pb.DonIDReply, error) {
166-
donID, err := s.impl.DonID(ctx)
168+
donID, err := s.getImpl().DonID(ctx)
167169
if err != nil {
168170
return nil, fmt.Errorf("failed to get DON ID: %w", err)
169171
}
170172
return &pb.DonIDReply{DonId: donID}, nil
171173
}
174+
172175
func (s GatewayConnectorServer) AwaitConnection(ctx context.Context, req *pb.GatewayIDRequest) (*emptypb.Empty, error) {
173-
if err := s.impl.AwaitConnection(ctx, req.GatewayId); err != nil {
176+
if err := s.getImpl().AwaitConnection(ctx, req.GatewayId); err != nil {
174177
return nil, fmt.Errorf("failed to await connection to gateway: %s: %w", req.GatewayId, err)
175178
}
176179
return &emptypb.Empty{}, nil
177180
}
181+
182+
func (s GatewayConnectorServer) getImpl() core.GatewayConnector {
183+
if s.impl == nil {
184+
return &core.UnimplementedGatewayConnector{}
185+
}
186+
return s.impl
187+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package gateway_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/smartcontractkit/chainlink-common/pkg/logger"
7+
"github.com/smartcontractkit/chainlink-common/pkg/loop/internal/core/services/gateway"
8+
loopnet "github.com/smartcontractkit/chainlink-common/pkg/loop/internal/net"
9+
loopnettest "github.com/smartcontractkit/chainlink-common/pkg/loop/internal/net/test"
10+
pb "github.com/smartcontractkit/chainlink-common/pkg/loop/internal/pb/gatewayconnector"
11+
"github.com/smartcontractkit/chainlink-common/pkg/types/core"
12+
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
func Test_GatewayConnectorServer(t *testing.T) {
17+
t.Run("calling AddHandler with a nil connector client does not panic", func(t *testing.T) {
18+
var (
19+
lggr = logger.Test(t)
20+
ctx = t.Context()
21+
handlerID = uint32(0)
22+
broker = &loopnettest.Broker{T: t}
23+
brokerCfg = loopnet.BrokerConfig{Logger: lggr, StopCh: make(chan struct{})}
24+
brokerExt = &loopnet.BrokerExt{
25+
Broker: broker,
26+
BrokerConfig: brokerCfg,
27+
}
28+
)
29+
30+
// allocate a listener for the handler
31+
_, err := broker.Accept(handlerID)
32+
require.NoError(t, err)
33+
34+
// create instance of connector server with empty GatewayConnector
35+
var gc core.GatewayConnector
36+
gcs := gateway.NewGatewayConnectorServer(brokerExt, gc)
37+
38+
// assert that call does not panic, yet errors
39+
res, err := gcs.AddHandler(ctx, &pb.AddHandlerRequest{
40+
HandlerId: handlerID,
41+
})
42+
require.Error(t, err)
43+
require.ErrorContains(t, err, "not implemented")
44+
require.Nil(t, res)
45+
})
46+
}

0 commit comments

Comments
 (0)