Skip to content

Commit 80e689e

Browse files
fix: LZReceiver implement full OApp (SC-1124, SC-1125) (#40)
* save * fix: review * fix: review * fix: comment
1 parent f20db48 commit 80e689e

File tree

3 files changed

+185
-51
lines changed

3 files changed

+185
-51
lines changed

src/receivers/LZReceiver.sol

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
// SPDX-License-Identifier: AGPL-3.0-or-later
22
pragma solidity ^0.8.0;
33

4-
import { Address } from "lib/openzeppelin-contracts/contracts/utils/Address.sol";
4+
import { Address } from "openzeppelin-contracts/contracts/utils/Address.sol";
5+
import { Ownable } from "openzeppelin-contracts/contracts/access/Ownable.sol";
56

6-
struct Origin {
7-
uint32 srcEid;
8-
bytes32 sender;
9-
uint64 nonce;
10-
}
7+
import { OApp, Origin } from "layerzerolabs/oapp-evm/contracts/oapp/OApp.sol";
118

129
/**
1310
* @title LZReceiver
1411
* @notice Receive messages from LayerZero-style bridge.
1512
*/
16-
contract LZReceiver {
13+
contract LZReceiver is OApp {
1714

1815
using Address for address;
1916

20-
address public immutable destinationEndpoint;
2117
address public immutable target;
2218

2319
uint32 public immutable srcEid;
@@ -28,30 +24,34 @@ contract LZReceiver {
2824
address _destinationEndpoint,
2925
uint32 _srcEid,
3026
bytes32 _sourceAuthority,
31-
address _target
32-
) {
33-
destinationEndpoint = _destinationEndpoint;
34-
target = _target;
35-
sourceAuthority = _sourceAuthority;
36-
srcEid = _srcEid;
27+
address _target,
28+
address _delegate,
29+
address _owner
30+
) OApp(_destinationEndpoint, _delegate) Ownable(_owner) {
31+
target = _target;
32+
sourceAuthority = _sourceAuthority;
33+
srcEid = _srcEid;
34+
35+
_setPeer(_srcEid, _sourceAuthority);
3736
}
3837

39-
function lzReceive(
38+
function _lzReceive(
4039
Origin calldata _origin,
4140
bytes32, // _guid
4241
bytes calldata _message,
43-
address, // _executor,
42+
address, // _executor
4443
bytes calldata // _extraData
45-
) external {
46-
require(msg.sender == destinationEndpoint, "LZReceiver/invalid-sender");
44+
) internal override {
4745
require(_origin.srcEid == srcEid, "LZReceiver/invalid-srcEid");
4846
require(_origin.sender == sourceAuthority, "LZReceiver/invalid-sourceAuthority");
4947

50-
target.functionCall(_message);
48+
target.functionCallWithValue(_message, msg.value);
5149
}
5250

53-
function allowInitializePath(Origin calldata origin) public view returns (bool) {
54-
return origin.srcEid == srcEid && origin.sender == sourceAuthority;
51+
function allowInitializePath(Origin calldata origin) public view override returns (bool) {
52+
return super.allowInitializePath(origin)
53+
&& origin.srcEid == srcEid
54+
&& origin.sender == sourceAuthority;
5555
}
5656

5757
}

test/LZIntegration.t.sol

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,19 @@ contract LZIntegrationTest is IntegrationBaseTest {
2626
Domain destination2;
2727
Bridge bridge2;
2828

29-
function test_invalidSender() public {
29+
error NoPeer(uint32 eid);
30+
error OnlyEndpoint(address addr);
31+
error OnlyPeer(uint32 eid, bytes32 sender);
32+
33+
function test_invalidEndpoint() public {
3034
destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE;
3135
destinationEndpoint = LZForwarder.ENDPOINT_BASE;
3236
initBaseContracts(getChain("base").createFork());
3337

3438
destination.selectFork();
3539

3640
vm.prank(randomAddress);
37-
vm.expectRevert("LZReceiver/invalid-sender");
41+
vm.expectRevert(abi.encodeWithSelector(OnlyEndpoint.selector, randomAddress));
3842
LZReceiver(destinationReceiver).lzReceive(
3943
Origin({
4044
srcEid: sourceEndpointId,
@@ -48,13 +52,61 @@ contract LZIntegrationTest is IntegrationBaseTest {
4852
);
4953
}
5054

55+
function test_lzReceive_revertsNoPeer() public {
56+
destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE;
57+
destinationEndpoint = LZForwarder.ENDPOINT_BASE;
58+
initBaseContracts(getChain("base").createFork());
59+
60+
destination.selectFork();
61+
62+
vm.prank(bridge.destinationCrossChainMessenger);
63+
vm.expectRevert(abi.encodeWithSelector(NoPeer.selector, 0));
64+
LZReceiver(destinationReceiver).lzReceive(
65+
Origin({
66+
srcEid: 0,
67+
sender: bytes32(uint256(uint160(sourceAuthority))),
68+
nonce: 1
69+
}),
70+
bytes32(0),
71+
abi.encodeCall(MessageOrdering.push, (1)),
72+
address(0),
73+
""
74+
);
75+
}
76+
77+
function test_lzReceive_revertsOnlyPeer() public {
78+
destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE;
79+
destinationEndpoint = LZForwarder.ENDPOINT_BASE;
80+
initBaseContracts(getChain("base").createFork());
81+
82+
destination.selectFork();
83+
84+
vm.prank(bridge.destinationCrossChainMessenger);
85+
vm.expectRevert(abi.encodeWithSelector(OnlyPeer.selector, sourceEndpointId, bytes32(uint256(uint160(randomAddress)))));
86+
LZReceiver(destinationReceiver).lzReceive(
87+
Origin({
88+
srcEid: sourceEndpointId,
89+
sender: bytes32(uint256(uint160(randomAddress))),
90+
nonce: 1
91+
}),
92+
bytes32(0),
93+
abi.encodeCall(MessageOrdering.push, (1)),
94+
address(0),
95+
""
96+
);
97+
}
98+
5199
function test_invalidSourceEid() public {
52100
destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE;
53101
destinationEndpoint = LZForwarder.ENDPOINT_BASE;
54102
initBaseContracts(getChain("base").createFork());
55103

56104
destination.selectFork();
57105

106+
// NOTE: To pass initial check, we set the peer.
107+
vm.prank(makeAddr("owner"));
108+
LZReceiver(destinationReceiver).setPeer(0, bytes32(uint256(uint160(sourceAuthority))));
109+
58110
vm.prank(bridge.destinationCrossChainMessenger);
59111
vm.expectRevert("LZReceiver/invalid-srcEid");
60112
LZReceiver(destinationReceiver).lzReceive(
@@ -77,6 +129,10 @@ contract LZIntegrationTest is IntegrationBaseTest {
77129

78130
destination.selectFork();
79131

132+
// NOTE: To pass initial check, we set the peer.
133+
vm.prank(makeAddr("owner"));
134+
LZReceiver(destinationReceiver).setPeer(sourceEndpointId, bytes32(uint256(uint160(randomAddress))));
135+
80136
vm.prank(bridge.destinationCrossChainMessenger);
81137
vm.expectRevert("LZReceiver/invalid-sourceAuthority");
82138
LZReceiver(destinationReceiver).lzReceive(
@@ -107,11 +163,25 @@ contract LZIntegrationTest is IntegrationBaseTest {
107163
}
108164

109165
function initSourceReceiver() internal override returns (address) {
110-
return address(new LZReceiver(sourceEndpoint, destinationEndpointId, bytes32(uint256(uint160(destinationAuthority))), address(moSource)));
166+
return address(new LZReceiver(
167+
sourceEndpoint,
168+
destinationEndpointId,
169+
bytes32(uint256(uint160(destinationAuthority))),
170+
address(moSource),
171+
makeAddr("delegate"),
172+
makeAddr("owner")
173+
));
111174
}
112175

113176
function initDestinationReceiver() internal override returns (address) {
114-
return address(new LZReceiver(destinationEndpoint, sourceEndpointId, bytes32(uint256(uint160(sourceAuthority))), address(moDestination)));
177+
return address(new LZReceiver(
178+
destinationEndpoint,
179+
sourceEndpointId,
180+
bytes32(uint256(uint160(sourceAuthority))),
181+
address(moDestination),
182+
makeAddr("delegate"),
183+
makeAddr("owner")
184+
));
115185
}
116186

117187
function initBridgeTesting() internal override returns (Bridge memory) {

test/LZReceiver.t.sol

Lines changed: 90 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ import { TargetContractMock } from "test/mocks/TargetContractMock.sol";
88
import { LZForwarder } from "src/forwarders/LZForwarder.sol";
99
import { LZReceiver, Origin } from "src/receivers/LZReceiver.sol";
1010

11+
interface ILayerZeroEndpointV2 {
12+
function delegates(address sender) external view returns (address);
13+
}
14+
1115
contract LZReceiverTest is Test {
1216

1317
TargetContractMock target;
@@ -17,41 +21,82 @@ contract LZReceiverTest is Test {
1721
address destinationEndpoint = LZForwarder.ENDPOINT_BNB;
1822
address randomAddress = makeAddr("randomAddress");
1923
address sourceAuthority = makeAddr("sourceAuthority");
20-
24+
address delegate = makeAddr("delegate");
25+
address owner = makeAddr("owner");
26+
2127
uint32 srcEid = LZForwarder.ENDPOINT_ID_ETHEREUM;
2228

29+
error NoPeer(uint32 eid);
30+
error OnlyEndpoint(address addr);
31+
error OnlyPeer(uint32 eid, bytes32 sender);
32+
2333
function setUp() public {
34+
vm.createSelectFork(getChain("bnb_smart_chain").rpcUrl);
35+
2436
target = new TargetContractMock();
2537

2638
receiver = new LZReceiver(
2739
destinationEndpoint,
2840
srcEid,
2941
bytes32(uint256(uint160(sourceAuthority))),
30-
address(target)
42+
address(target),
43+
delegate,
44+
owner
3145
);
3246
}
3347

34-
function test_constructor() public {
35-
receiver = new LZReceiver(
36-
destinationEndpoint,
37-
srcEid,
38-
bytes32(uint256(uint160(sourceAuthority))),
39-
address(target)
40-
);
48+
function test_constructor() public view {
49+
assertEq(receiver.srcEid(), srcEid);
50+
assertEq(receiver.sourceAuthority(), bytes32(uint256(uint160(sourceAuthority))));
51+
assertEq(receiver.target(), address(target));
52+
assertEq(receiver.owner(), owner);
53+
assertEq(receiver.peers(srcEid), bytes32(uint256(uint160(sourceAuthority))));
4154

42-
assertEq(receiver.destinationEndpoint(), destinationEndpoint);
43-
assertEq(receiver.srcEid(), srcEid);
44-
assertEq(receiver.sourceAuthority(), bytes32(uint256(uint160(sourceAuthority))));
45-
assertEq(receiver.target(), address(target));
55+
assertEq(
56+
ILayerZeroEndpointV2(address(receiver.endpoint())).delegates(address(receiver)),
57+
delegate
58+
);
4659
}
4760

48-
function test_lzReceive_invalidSender() public {
61+
function test_invalidEndpoint() public {
4962
vm.prank(randomAddress);
50-
vm.expectRevert("LZReceiver/invalid-sender");
63+
vm.expectRevert(abi.encodeWithSelector(OnlyEndpoint.selector, randomAddress));
5164
receiver.lzReceive(
5265
Origin({
5366
srcEid: srcEid,
54-
sender: bytes32(uint256(uint160(sourceAuthority))),
67+
sender: bytes32(uint256(uint160(randomAddress))),
68+
nonce: 1
69+
}),
70+
bytes32(0),
71+
abi.encodeCall(TargetContractMock.increment, ()),
72+
address(0),
73+
""
74+
);
75+
}
76+
77+
function test_lzReceive_revertsNoPeer() public {
78+
vm.prank(destinationEndpoint);
79+
vm.expectRevert(abi.encodeWithSelector(NoPeer.selector, 0));
80+
receiver.lzReceive(
81+
Origin({
82+
srcEid: 0,
83+
sender: bytes32(uint256(uint160(randomAddress))),
84+
nonce: 1
85+
}),
86+
bytes32(0),
87+
abi.encodeCall(TargetContractMock.increment, ()),
88+
address(0),
89+
""
90+
);
91+
}
92+
93+
function test_lzReceive_revertsOnlyPeer() public {
94+
vm.prank(destinationEndpoint);
95+
vm.expectRevert(abi.encodeWithSelector(OnlyPeer.selector, srcEid, bytes32(uint256(uint160(randomAddress)))));
96+
receiver.lzReceive(
97+
Origin({
98+
srcEid: srcEid,
99+
sender: bytes32(uint256(uint160(randomAddress))),
55100
nonce: 1
56101
}),
57102
bytes32(0),
@@ -62,6 +107,10 @@ contract LZReceiverTest is Test {
62107
}
63108

64109
function test_lzReceive_invalidSrcEid() public {
110+
// NOTE: To pass initial check, we set the peer.
111+
vm.prank(owner);
112+
receiver.setPeer(srcEid + 1, bytes32(uint256(uint160(sourceAuthority))));
113+
65114
vm.prank(destinationEndpoint);
66115
vm.expectRevert("LZReceiver/invalid-srcEid");
67116
receiver.lzReceive(
@@ -78,6 +127,10 @@ contract LZReceiverTest is Test {
78127
}
79128

80129
function test_lzReceive_invalidSourceAuthority() public {
130+
// NOTE: To pass initial check, we set the peer.
131+
vm.prank(owner);
132+
receiver.setPeer(srcEid, bytes32(uint256(uint160(randomAddress))));
133+
81134
vm.prank(destinationEndpoint);
82135
vm.expectRevert("LZReceiver/invalid-sourceAuthority");
83136
receiver.lzReceive(
@@ -110,34 +163,45 @@ contract LZReceiverTest is Test {
110163
assertEq(target.count(), 1);
111164
}
112165

113-
function test_allowInitializePath() public view {
114-
// Should return true when origin.srcEid == srcEid and origin.sender == sourceAuthority
166+
function test_allowInitializePath() public {
167+
// Should return true when origin.srcEid == srcEid, origin.sender == sourceAuthority and peers[origin.srcEid] == origin.sender
115168
assertTrue(receiver.allowInitializePath(Origin({
116169
srcEid: srcEid,
117170
sender: bytes32(uint256(uint160(sourceAuthority))),
118171
nonce: 1
119172
})));
120173

174+
// Should return false when peers[origin.srcEid] != origin.sender
175+
176+
assertFalse(receiver.allowInitializePath(Origin({
177+
srcEid: srcEid,
178+
sender: bytes32(uint256(uint160(randomAddress))),
179+
nonce: 1
180+
})));
181+
121182
// Should return false when origin.srcEid != srcEid
183+
184+
// NOTE: Setting peer to make `super.allowInitializePath(origin)` return true
185+
vm.prank(owner);
186+
receiver.setPeer(srcEid + 1, bytes32(uint256(uint160(sourceAuthority))));
187+
122188
assertFalse(receiver.allowInitializePath(Origin({
123189
srcEid: srcEid + 1,
124190
sender: bytes32(uint256(uint160(sourceAuthority))),
125191
nonce: 1
126192
})));
127193

128194
// Should return false when origin.sender != sourceAuthority
129-
assertFalse(receiver.allowInitializePath(Origin({
130-
srcEid: srcEid,
131-
sender: bytes32(uint256(uint160(randomAddress))),
132-
nonce: 1
133-
})));
134195

135-
// Should return false when origin.srcEid != srcEid and origin.sender != sourceAuthority
196+
// NOTE: Setting peer to make `super.allowInitializePath(origin)` return true
197+
vm.prank(owner);
198+
receiver.setPeer(srcEid, bytes32(uint256(uint160(randomAddress))));
199+
136200
assertFalse(receiver.allowInitializePath(Origin({
137-
srcEid: srcEid + 1,
201+
srcEid: srcEid,
138202
sender: bytes32(uint256(uint160(randomAddress))),
139203
nonce: 1
140204
})));
141205
}
142-
206+
143207
}

0 commit comments

Comments
 (0)