Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions src/receivers/LZReceiver.sol
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
// SPDX-License-Identifier: AGPL-3.0-or-later
pragma solidity ^0.8.0;

import { Address } from "lib/openzeppelin-contracts/contracts/utils/Address.sol";
import { Address } from "openzeppelin-contracts/contracts/utils/Address.sol";
import { Ownable } from "openzeppelin-contracts/contracts/access/Ownable.sol";

struct Origin {
uint32 srcEid;
bytes32 sender;
uint64 nonce;
}
import { OApp, Origin } from "layerzerolabs/oapp-evm/contracts/oapp/OApp.sol";

/**
* @title LZReceiver
* @notice Receive messages from LayerZero-style bridge.
*/
contract LZReceiver {
contract LZReceiver is OApp {

using Address for address;

address public immutable destinationEndpoint;
address public immutable target;

uint32 public immutable srcEid;
Expand All @@ -28,30 +24,28 @@ contract LZReceiver {
address _destinationEndpoint,
uint32 _srcEid,
bytes32 _sourceAuthority,
address _target
) {
destinationEndpoint = _destinationEndpoint;
target = _target;
sourceAuthority = _sourceAuthority;
srcEid = _srcEid;
address _target,
address _delegate,
address _owner
) OApp(_destinationEndpoint, _delegate) Ownable(_owner) {
target = _target;
sourceAuthority = _sourceAuthority;
srcEid = _srcEid;

_setPeer(_srcEid, _sourceAuthority);
}

function lzReceive(
function _lzReceive(
Origin calldata _origin,
bytes32, // _guid
bytes calldata _message,
address, // _executor,
address, // _executor
bytes calldata // _extraData
) external {
require(msg.sender == destinationEndpoint, "LZReceiver/invalid-sender");
) internal override {
require(_origin.srcEid == srcEid, "LZReceiver/invalid-srcEid");
require(_origin.sender == sourceAuthority, "LZReceiver/invalid-sourceAuthority");

target.functionCall(_message);
}

function allowInitializePath(Origin calldata origin) public view returns (bool) {
return origin.srcEid == srcEid && origin.sender == sourceAuthority;
target.functionCallWithValue(_message, msg.value);
}

}
78 changes: 74 additions & 4 deletions test/LZIntegration.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,19 @@ contract LZIntegrationTest is IntegrationBaseTest {
Domain destination2;
Bridge bridge2;

function test_invalidSender() public {
error NoPeer(uint32 eid);
error OnlyEndpoint(address addr);
error OnlyPeer(uint32 eid, bytes32 sender);

function test_invalidEndpoint() public {
destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE;
destinationEndpoint = LZForwarder.ENDPOINT_BASE;
initBaseContracts(getChain("base").createFork());

destination.selectFork();

vm.prank(randomAddress);
vm.expectRevert("LZReceiver/invalid-sender");
vm.expectRevert(abi.encodeWithSelector(OnlyEndpoint.selector, randomAddress));
LZReceiver(destinationReceiver).lzReceive(
Origin({
srcEid: sourceEndpointId,
Expand All @@ -48,13 +52,61 @@ contract LZIntegrationTest is IntegrationBaseTest {
);
}

function test_lzReceive_revertsNoPeer() public {
destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE;
destinationEndpoint = LZForwarder.ENDPOINT_BASE;
initBaseContracts(getChain("base").createFork());

destination.selectFork();

vm.prank(bridge.destinationCrossChainMessenger);
vm.expectRevert(abi.encodeWithSelector(NoPeer.selector, 0));
LZReceiver(destinationReceiver).lzReceive(
Origin({
srcEid: 0,
sender: bytes32(uint256(uint160(sourceAuthority))),
nonce: 1
}),
bytes32(0),
abi.encodeCall(MessageOrdering.push, (1)),
address(0),
""
);
}

function test_lzReceive_revertsOnlyPeer() public {
destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE;
destinationEndpoint = LZForwarder.ENDPOINT_BASE;
initBaseContracts(getChain("base").createFork());

destination.selectFork();

vm.prank(bridge.destinationCrossChainMessenger);
vm.expectRevert(abi.encodeWithSelector(OnlyPeer.selector, sourceEndpointId, bytes32(uint256(uint160(randomAddress)))));
LZReceiver(destinationReceiver).lzReceive(
Origin({
srcEid: sourceEndpointId,
sender: bytes32(uint256(uint160(randomAddress))),
nonce: 1
}),
bytes32(0),
abi.encodeCall(MessageOrdering.push, (1)),
address(0),
""
);
}

function test_invalidSourceEid() public {
destinationEndpointId = LZForwarder.ENDPOINT_ID_BASE;
destinationEndpoint = LZForwarder.ENDPOINT_BASE;
initBaseContracts(getChain("base").createFork());

destination.selectFork();

// NOTE: To pass initial check, we set the peer.
vm.prank(makeAddr("owner"));
LZReceiver(destinationReceiver).setPeer(0, bytes32(uint256(uint160(sourceAuthority))));

vm.prank(bridge.destinationCrossChainMessenger);
vm.expectRevert("LZReceiver/invalid-srcEid");
LZReceiver(destinationReceiver).lzReceive(
Expand All @@ -77,6 +129,10 @@ contract LZIntegrationTest is IntegrationBaseTest {

destination.selectFork();

// NOTE: To pass initial check, we set the peer.
vm.prank(makeAddr("owner"));
LZReceiver(destinationReceiver).setPeer(sourceEndpointId, bytes32(uint256(uint160(randomAddress))));

vm.prank(bridge.destinationCrossChainMessenger);
vm.expectRevert("LZReceiver/invalid-sourceAuthority");
LZReceiver(destinationReceiver).lzReceive(
Expand Down Expand Up @@ -107,11 +163,25 @@ contract LZIntegrationTest is IntegrationBaseTest {
}

function initSourceReceiver() internal override returns (address) {
return address(new LZReceiver(sourceEndpoint, destinationEndpointId, bytes32(uint256(uint160(destinationAuthority))), address(moSource)));
return address(new LZReceiver(
sourceEndpoint,
destinationEndpointId,
bytes32(uint256(uint160(destinationAuthority))),
address(moSource),
makeAddr("delegate"),
makeAddr("owner")
));
}

function initDestinationReceiver() internal override returns (address) {
return address(new LZReceiver(destinationEndpoint, sourceEndpointId, bytes32(uint256(uint160(sourceAuthority))), address(moDestination)));
return address(new LZReceiver(
destinationEndpoint,
sourceEndpointId,
bytes32(uint256(uint160(sourceAuthority))),
address(moDestination),
makeAddr("delegate"),
makeAddr("owner")
));
}

function initBridgeTesting() internal override returns (Bridge memory) {
Expand Down
115 changes: 69 additions & 46 deletions test/LZReceiver.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import { TargetContractMock } from "test/mocks/TargetContractMock.sol";
import { LZForwarder } from "src/forwarders/LZForwarder.sol";
import { LZReceiver, Origin } from "src/receivers/LZReceiver.sol";

interface ILayerZeroEndpointV2 {
function delegates(address sender) external view returns (address);
}

contract LZReceiverTest is Test {

TargetContractMock target;
Expand All @@ -17,41 +21,82 @@ contract LZReceiverTest is Test {
address destinationEndpoint = LZForwarder.ENDPOINT_BNB;
address randomAddress = makeAddr("randomAddress");
address sourceAuthority = makeAddr("sourceAuthority");

address delegate = makeAddr("delegate");
address owner = makeAddr("owner");

uint32 srcEid = LZForwarder.ENDPOINT_ID_ETHEREUM;

error NoPeer(uint32 eid);
error OnlyEndpoint(address addr);
error OnlyPeer(uint32 eid, bytes32 sender);

function setUp() public {
vm.createSelectFork(getChain("bnb_smart_chain").rpcUrl);

target = new TargetContractMock();

receiver = new LZReceiver(
destinationEndpoint,
srcEid,
bytes32(uint256(uint160(sourceAuthority))),
address(target)
address(target),
delegate,
owner
);
}

function test_constructor() public {
receiver = new LZReceiver(
destinationEndpoint,
srcEid,
bytes32(uint256(uint160(sourceAuthority))),
address(target)
);
function test_constructor() public view {
assertEq(receiver.srcEid(), srcEid);
assertEq(receiver.sourceAuthority(), bytes32(uint256(uint160(sourceAuthority))));
assertEq(receiver.target(), address(target));
assertEq(receiver.owner(), owner);
assertEq(receiver.peers(srcEid), bytes32(uint256(uint160(sourceAuthority))));

assertEq(receiver.destinationEndpoint(), destinationEndpoint);
assertEq(receiver.srcEid(), srcEid);
assertEq(receiver.sourceAuthority(), bytes32(uint256(uint160(sourceAuthority))));
assertEq(receiver.target(), address(target));
assertEq(
ILayerZeroEndpointV2(address(receiver.endpoint())).delegates(address(receiver)),
delegate
);
}

function test_lzReceive_invalidSender() public {
function test_invalidEndpoint() public {
vm.prank(randomAddress);
vm.expectRevert("LZReceiver/invalid-sender");
vm.expectRevert(abi.encodeWithSelector(OnlyEndpoint.selector, randomAddress));
receiver.lzReceive(
Origin({
srcEid: srcEid,
sender: bytes32(uint256(uint160(sourceAuthority))),
sender: bytes32(uint256(uint160(randomAddress))),
nonce: 1
}),
bytes32(0),
abi.encodeCall(TargetContractMock.increment, ()),
address(0),
""
);
}

function test_lzReceive_revertsNoPeer() public {
vm.prank(destinationEndpoint);
vm.expectRevert(abi.encodeWithSelector(NoPeer.selector, 0));
receiver.lzReceive(
Origin({
srcEid: 0,
sender: bytes32(uint256(uint160(randomAddress))),
nonce: 1
}),
bytes32(0),
abi.encodeCall(TargetContractMock.increment, ()),
address(0),
""
);
}

function test_lzReceive_revertsOnlyPeer() public {
vm.prank(destinationEndpoint);
vm.expectRevert(abi.encodeWithSelector(OnlyPeer.selector, srcEid, bytes32(uint256(uint160(randomAddress)))));
receiver.lzReceive(
Origin({
srcEid: srcEid,
sender: bytes32(uint256(uint160(randomAddress))),
nonce: 1
}),
bytes32(0),
Expand All @@ -62,6 +107,10 @@ contract LZReceiverTest is Test {
}

function test_lzReceive_invalidSrcEid() public {
// NOTE: To pass initial check, we set the peer.
vm.prank(owner);
receiver.setPeer(srcEid + 1, bytes32(uint256(uint160(sourceAuthority))));

vm.prank(destinationEndpoint);
vm.expectRevert("LZReceiver/invalid-srcEid");
receiver.lzReceive(
Expand All @@ -78,6 +127,10 @@ contract LZReceiverTest is Test {
}

function test_lzReceive_invalidSourceAuthority() public {
// NOTE: To pass initial check, we set the peer.
vm.prank(owner);
receiver.setPeer(srcEid, bytes32(uint256(uint160(randomAddress))));

vm.prank(destinationEndpoint);
vm.expectRevert("LZReceiver/invalid-sourceAuthority");
receiver.lzReceive(
Expand Down Expand Up @@ -110,34 +163,4 @@ contract LZReceiverTest is Test {
assertEq(target.count(), 1);
}

function test_allowInitializePath() public view {
// Should return true when origin.srcEid == srcEid and origin.sender == sourceAuthority
assertTrue(receiver.allowInitializePath(Origin({
srcEid: srcEid,
sender: bytes32(uint256(uint160(sourceAuthority))),
nonce: 1
})));

// Should return false when origin.srcEid != srcEid
assertFalse(receiver.allowInitializePath(Origin({
srcEid: srcEid + 1,
sender: bytes32(uint256(uint160(sourceAuthority))),
nonce: 1
})));

// Should return false when origin.sender != sourceAuthority
assertFalse(receiver.allowInitializePath(Origin({
srcEid: srcEid,
sender: bytes32(uint256(uint160(randomAddress))),
nonce: 1
})));

// Should return false when origin.srcEid != srcEid and origin.sender != sourceAuthority
assertFalse(receiver.allowInitializePath(Origin({
srcEid: srcEid + 1,
sender: bytes32(uint256(uint160(randomAddress))),
nonce: 1
})));
}

}
Loading