diff --git a/src/receivers/LZReceiver.sol b/src/receivers/LZReceiver.sol index 59252d0..d16ec90 100644 --- a/src/receivers/LZReceiver.sol +++ b/src/receivers/LZReceiver.sol @@ -1,23 +1,19 @@ // SPDX-License-Identifier: AGPL-3.0-or-later pragma solidity ^0.8.0; -import { Address } from "lib/openzeppelin-contracts/contracts/utils/Address.sol"; +import { Address } from "openzeppelin-contracts/contracts/utils/Address.sol"; +import { Ownable } from "openzeppelin-contracts/contracts/access/Ownable.sol"; -struct Origin { - uint32 srcEid; - bytes32 sender; - uint64 nonce; -} +import { OApp, Origin } from "layerzerolabs/oapp-evm/contracts/oapp/OApp.sol"; /** * @title LZReceiver * @notice Receive messages from LayerZero-style bridge. */ -contract LZReceiver { +contract LZReceiver is OApp { using Address for address; - address public immutable destinationEndpoint; address public immutable target; uint32 public immutable srcEid; @@ -28,30 +24,34 @@ contract LZReceiver { address _destinationEndpoint, uint32 _srcEid, bytes32 _sourceAuthority, - address _target - ) { - destinationEndpoint = _destinationEndpoint; - target = _target; - sourceAuthority = _sourceAuthority; - srcEid = _srcEid; + address _target, + address _delegate, + address _owner + ) OApp(_destinationEndpoint, _delegate) Ownable(_owner) { + target = _target; + sourceAuthority = _sourceAuthority; + srcEid = _srcEid; + + _setPeer(_srcEid, _sourceAuthority); } - function lzReceive( + function _lzReceive( Origin calldata _origin, bytes32, // _guid bytes calldata _message, - address, // _executor, + address, // _executor bytes calldata // _extraData - ) external { - require(msg.sender == destinationEndpoint, "LZReceiver/invalid-sender"); + ) internal override { require(_origin.srcEid == srcEid, "LZReceiver/invalid-srcEid"); require(_origin.sender == sourceAuthority, "LZReceiver/invalid-sourceAuthority"); - target.functionCall(_message); + target.functionCallWithValue(_message, msg.value); } - function allowInitializePath(Origin calldata origin) public view returns (bool) { - return origin.srcEid == srcEid && origin.sender == sourceAuthority; + function allowInitializePath(Origin calldata origin) public view override returns (bool) { + return super.allowInitializePath(origin) + && origin.srcEid == srcEid + && origin.sender == sourceAuthority; } } diff --git a/test/LZIntegration.t.sol b/test/LZIntegration.t.sol index f6b0ea2..ffea7d3 100644 --- a/test/LZIntegration.t.sol +++ b/test/LZIntegration.t.sol @@ -26,7 +26,11 @@ contract LZIntegrationTest is IntegrationBaseTest { Domain destination2; Bridge bridge2; - function test_invalidSender() public { + error NoPeer(uint32 eid); + error OnlyEndpoint(address addr); + error OnlyPeer(uint32 eid, bytes32 sender); + + function test_invalidEndpoint() public { destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE; destinationEndpoint = LZForwarder.ENDPOINT_BASE; initBaseContracts(getChain("base").createFork()); @@ -34,7 +38,7 @@ contract LZIntegrationTest is IntegrationBaseTest { destination.selectFork(); vm.prank(randomAddress); - vm.expectRevert("LZReceiver/invalid-sender"); + vm.expectRevert(abi.encodeWithSelector(OnlyEndpoint.selector, randomAddress)); LZReceiver(destinationReceiver).lzReceive( Origin({ srcEid: sourceEndpointId, @@ -48,6 +52,50 @@ contract LZIntegrationTest is IntegrationBaseTest { ); } + function test_lzReceive_revertsNoPeer() public { + destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE; + destinationEndpoint = LZForwarder.ENDPOINT_BASE; + initBaseContracts(getChain("base").createFork()); + + destination.selectFork(); + + vm.prank(bridge.destinationCrossChainMessenger); + vm.expectRevert(abi.encodeWithSelector(NoPeer.selector, 0)); + LZReceiver(destinationReceiver).lzReceive( + Origin({ + srcEid: 0, + sender: bytes32(uint256(uint160(sourceAuthority))), + nonce: 1 + }), + bytes32(0), + abi.encodeCall(MessageOrdering.push, (1)), + address(0), + "" + ); + } + + function test_lzReceive_revertsOnlyPeer() public { + destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE; + destinationEndpoint = LZForwarder.ENDPOINT_BASE; + initBaseContracts(getChain("base").createFork()); + + destination.selectFork(); + + vm.prank(bridge.destinationCrossChainMessenger); + vm.expectRevert(abi.encodeWithSelector(OnlyPeer.selector, sourceEndpointId, bytes32(uint256(uint160(randomAddress))))); + LZReceiver(destinationReceiver).lzReceive( + Origin({ + srcEid: sourceEndpointId, + sender: bytes32(uint256(uint160(randomAddress))), + nonce: 1 + }), + bytes32(0), + abi.encodeCall(MessageOrdering.push, (1)), + address(0), + "" + ); + } + function test_invalidSourceEid() public { destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE; destinationEndpoint = LZForwarder.ENDPOINT_BASE; @@ -55,6 +103,10 @@ contract LZIntegrationTest is IntegrationBaseTest { destination.selectFork(); + // NOTE: To pass initial check, we set the peer. + vm.prank(makeAddr("owner")); + LZReceiver(destinationReceiver).setPeer(0, bytes32(uint256(uint160(sourceAuthority)))); + vm.prank(bridge.destinationCrossChainMessenger); vm.expectRevert("LZReceiver/invalid-srcEid"); LZReceiver(destinationReceiver).lzReceive( @@ -77,6 +129,10 @@ contract LZIntegrationTest is IntegrationBaseTest { destination.selectFork(); + // NOTE: To pass initial check, we set the peer. + vm.prank(makeAddr("owner")); + LZReceiver(destinationReceiver).setPeer(sourceEndpointId, bytes32(uint256(uint160(randomAddress)))); + vm.prank(bridge.destinationCrossChainMessenger); vm.expectRevert("LZReceiver/invalid-sourceAuthority"); LZReceiver(destinationReceiver).lzReceive( @@ -107,11 +163,25 @@ contract LZIntegrationTest is IntegrationBaseTest { } function initSourceReceiver() internal override returns (address) { - return address(new LZReceiver(sourceEndpoint, destinationEndpointId, bytes32(uint256(uint160(destinationAuthority))), address(moSource))); + return address(new LZReceiver( + sourceEndpoint, + destinationEndpointId, + bytes32(uint256(uint160(destinationAuthority))), + address(moSource), + makeAddr("delegate"), + makeAddr("owner") + )); } function initDestinationReceiver() internal override returns (address) { - return address(new LZReceiver(destinationEndpoint, sourceEndpointId, bytes32(uint256(uint160(sourceAuthority))), address(moDestination))); + return address(new LZReceiver( + destinationEndpoint, + sourceEndpointId, + bytes32(uint256(uint160(sourceAuthority))), + address(moDestination), + makeAddr("delegate"), + makeAddr("owner") + )); } function initBridgeTesting() internal override returns (Bridge memory) { diff --git a/test/LZReceiver.t.sol b/test/LZReceiver.t.sol index be55682..af820f7 100644 --- a/test/LZReceiver.t.sol +++ b/test/LZReceiver.t.sol @@ -8,6 +8,10 @@ import { TargetContractMock } from "test/mocks/TargetContractMock.sol"; import { LZForwarder } from "src/forwarders/LZForwarder.sol"; import { LZReceiver, Origin } from "src/receivers/LZReceiver.sol"; +interface ILayerZeroEndpointV2 { + function delegates(address sender) external view returns (address); +} + contract LZReceiverTest is Test { TargetContractMock target; @@ -17,41 +21,82 @@ contract LZReceiverTest is Test { address destinationEndpoint = LZForwarder.ENDPOINT_BNB; address randomAddress = makeAddr("randomAddress"); address sourceAuthority = makeAddr("sourceAuthority"); - + address delegate = makeAddr("delegate"); + address owner = makeAddr("owner"); + uint32 srcEid = LZForwarder.ENDPOINT_ID_ETHEREUM; + error NoPeer(uint32 eid); + error OnlyEndpoint(address addr); + error OnlyPeer(uint32 eid, bytes32 sender); + function setUp() public { + vm.createSelectFork(getChain("bnb_smart_chain").rpcUrl); + target = new TargetContractMock(); receiver = new LZReceiver( destinationEndpoint, srcEid, bytes32(uint256(uint160(sourceAuthority))), - address(target) + address(target), + delegate, + owner ); } - function test_constructor() public { - receiver = new LZReceiver( - destinationEndpoint, - srcEid, - bytes32(uint256(uint160(sourceAuthority))), - address(target) - ); + function test_constructor() public view { + assertEq(receiver.srcEid(), srcEid); + assertEq(receiver.sourceAuthority(), bytes32(uint256(uint160(sourceAuthority)))); + assertEq(receiver.target(), address(target)); + assertEq(receiver.owner(), owner); + assertEq(receiver.peers(srcEid), bytes32(uint256(uint160(sourceAuthority)))); - assertEq(receiver.destinationEndpoint(), destinationEndpoint); - assertEq(receiver.srcEid(), srcEid); - assertEq(receiver.sourceAuthority(), bytes32(uint256(uint160(sourceAuthority)))); - assertEq(receiver.target(), address(target)); + assertEq( + ILayerZeroEndpointV2(address(receiver.endpoint())).delegates(address(receiver)), + delegate + ); } - function test_lzReceive_invalidSender() public { + function test_invalidEndpoint() public { vm.prank(randomAddress); - vm.expectRevert("LZReceiver/invalid-sender"); + vm.expectRevert(abi.encodeWithSelector(OnlyEndpoint.selector, randomAddress)); receiver.lzReceive( Origin({ srcEid: srcEid, - sender: bytes32(uint256(uint160(sourceAuthority))), + sender: bytes32(uint256(uint160(randomAddress))), + nonce: 1 + }), + bytes32(0), + abi.encodeCall(TargetContractMock.increment, ()), + address(0), + "" + ); + } + + function test_lzReceive_revertsNoPeer() public { + vm.prank(destinationEndpoint); + vm.expectRevert(abi.encodeWithSelector(NoPeer.selector, 0)); + receiver.lzReceive( + Origin({ + srcEid: 0, + sender: bytes32(uint256(uint160(randomAddress))), + nonce: 1 + }), + bytes32(0), + abi.encodeCall(TargetContractMock.increment, ()), + address(0), + "" + ); + } + + function test_lzReceive_revertsOnlyPeer() public { + vm.prank(destinationEndpoint); + vm.expectRevert(abi.encodeWithSelector(OnlyPeer.selector, srcEid, bytes32(uint256(uint160(randomAddress))))); + receiver.lzReceive( + Origin({ + srcEid: srcEid, + sender: bytes32(uint256(uint160(randomAddress))), nonce: 1 }), bytes32(0), @@ -62,6 +107,10 @@ contract LZReceiverTest is Test { } function test_lzReceive_invalidSrcEid() public { + // NOTE: To pass initial check, we set the peer. + vm.prank(owner); + receiver.setPeer(srcEid + 1, bytes32(uint256(uint160(sourceAuthority)))); + vm.prank(destinationEndpoint); vm.expectRevert("LZReceiver/invalid-srcEid"); receiver.lzReceive( @@ -78,6 +127,10 @@ contract LZReceiverTest is Test { } function test_lzReceive_invalidSourceAuthority() public { + // NOTE: To pass initial check, we set the peer. + vm.prank(owner); + receiver.setPeer(srcEid, bytes32(uint256(uint160(randomAddress)))); + vm.prank(destinationEndpoint); vm.expectRevert("LZReceiver/invalid-sourceAuthority"); receiver.lzReceive( @@ -110,15 +163,28 @@ contract LZReceiverTest is Test { assertEq(target.count(), 1); } - function test_allowInitializePath() public view { - // Should return true when origin.srcEid == srcEid and origin.sender == sourceAuthority + function test_allowInitializePath() public { + // Should return true when origin.srcEid == srcEid, origin.sender == sourceAuthority and peers[origin.srcEid] == origin.sender assertTrue(receiver.allowInitializePath(Origin({ srcEid: srcEid, sender: bytes32(uint256(uint160(sourceAuthority))), nonce: 1 }))); + // Should return false when peers[origin.srcEid] != origin.sender + + assertFalse(receiver.allowInitializePath(Origin({ + srcEid: srcEid, + sender: bytes32(uint256(uint160(randomAddress))), + nonce: 1 + }))); + // Should return false when origin.srcEid != srcEid + + // NOTE: Setting peer to make `super.allowInitializePath(origin)` return true + vm.prank(owner); + receiver.setPeer(srcEid + 1, bytes32(uint256(uint160(sourceAuthority)))); + assertFalse(receiver.allowInitializePath(Origin({ srcEid: srcEid + 1, sender: bytes32(uint256(uint160(sourceAuthority))), @@ -126,18 +192,16 @@ contract LZReceiverTest is Test { }))); // Should return false when origin.sender != sourceAuthority - assertFalse(receiver.allowInitializePath(Origin({ - srcEid: srcEid, - sender: bytes32(uint256(uint160(randomAddress))), - nonce: 1 - }))); - // Should return false when origin.srcEid != srcEid and origin.sender != sourceAuthority + // NOTE: Setting peer to make `super.allowInitializePath(origin)` return true + vm.prank(owner); + receiver.setPeer(srcEid, bytes32(uint256(uint160(randomAddress)))); + assertFalse(receiver.allowInitializePath(Origin({ - srcEid: srcEid + 1, + srcEid: srcEid, sender: bytes32(uint256(uint160(randomAddress))), nonce: 1 }))); } - + }