diff --git a/evm/script/helpers/DeployWormholeNttBase.sol b/evm/script/helpers/DeployWormholeNttBase.sol index a8b236b83..05851ab42 100644 --- a/evm/script/helpers/DeployWormholeNttBase.sol +++ b/evm/script/helpers/DeployWormholeNttBase.sol @@ -100,7 +100,8 @@ contract DeployWormholeNttBase is ParseNttConfig { } // Hardcoded to one since these scripts handle Wormhole-only deployments. - INttManager(nttManager).setThreshold(1); + // TODO: We need the sourceChainId to set the threshold. Also need to enable sending and receiving. + // INttManager(nttManager).setThreshold(1); console2.log("Threshold set on NttManager: %d", uint256(1)); } diff --git a/evm/src/NttManager/ManagerBase.sol b/evm/src/NttManager/ManagerBase.sol index 8a336fccc..9a4538442 100644 --- a/evm/src/NttManager/ManagerBase.sol +++ b/evm/src/NttManager/ManagerBase.sol @@ -52,7 +52,6 @@ abstract contract ManagerBase is } function _migrate() internal virtual override { - _checkThresholdInvariants(); _checkTransceiversInvariants(); } @@ -68,13 +67,6 @@ abstract contract ManagerBase is // =============== Storage Getters/Setters ============================================== - function _getThresholdStorage() private pure returns (_Threshold storage $) { - uint256 slot = uint256(THRESHOLD_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - function _getMessageAttestationsStorage() internal pure @@ -139,15 +131,19 @@ abstract contract ManagerBase is bytes32 nttManagerMessageHash = TransceiverStructs.nttManagerMessageDigest(sourceChainId, payload); + // The `msg.sender` is the transceiver. Get the index for it. + uint8 index = _getTransceiverInfosStorage()[msg.sender].index; + + // TODO: Is there a race condition with disabling a transceiver while a tx is outstanding? + if (!_isRecvTransceiverEnabledForChain(sourceChainId, index)) { + revert CallerNotTransceiver(msg.sender); + } + // set the attested flag for this transceiver. // NOTE: Attestation is idempotent (bitwise or 1), but we revert // anyway to ensure that the client does not continue to initiate calls // to receive the same message through the same transceiver. - if ( - transceiverAttestedToMessage( - nttManagerMessageHash, _getTransceiverInfosStorage()[msg.sender].index - ) - ) { + if (transceiverAttestedToMessage(nttManagerMessageHash, index)) { revert TransceiverAlreadyAttestedToMessage(nttManagerMessageHash); } _setTransceiverAttestedToMessage(nttManagerMessageHash, msg.sender); @@ -162,7 +158,7 @@ abstract contract ManagerBase is ) internal returns (bytes32, bool) { bytes32 digest = TransceiverStructs.nttManagerMessageDigest(sourceChainId, message); - if (!isMessageApproved(digest)) { + if (!isMessageApproved(sourceChainId, digest)) { revert MessageNotApproved(digest); } @@ -225,7 +221,7 @@ abstract contract ManagerBase is ) { // cache enabled transceivers to avoid multiple storage reads - address[] memory enabledTransceivers = _getEnabledTransceiversStorage(); + address[] memory enabledTransceivers = getEnabledSendTransceiversForChain(recipientChain); TransceiverStructs.TransceiverInstruction[] memory instructions; @@ -280,15 +276,16 @@ abstract contract ManagerBase is } /// @inheritdoc IManagerBase - function getThreshold() public view returns (uint8) { - return _getThresholdStorage().num; + /// @dev This is here because it is defined in IManagerBase. + function getThreshold( + uint16 sourceChainId + ) public view returns (uint8) { + return _getPerChainRecvTransceiverDataStorage()[sourceChainId].threshold; } /// @inheritdoc IManagerBase - function isMessageApproved( - bytes32 digest - ) public view returns (bool) { - uint8 threshold = getThreshold(); + function isMessageApproved(uint16 sourceChainId, bytes32 digest) public view returns (bool) { + uint8 threshold = getThreshold(sourceChainId); return messageAttestations(digest) >= threshold && threshold > 0; } @@ -354,28 +351,7 @@ abstract contract ManagerBase is address transceiver ) external onlyOwner { _setTransceiver(transceiver); - - _Threshold storage _threshold = _getThresholdStorage(); - // We do not automatically increase the threshold here. - // Automatically increasing the threshold can result in a scenario - // where in-flight messages can't be redeemed. - // For example: Assume there is 1 Transceiver and the threshold is 1. - // If we were to add a new Transceiver, the threshold would increase to 2. - // However, all messages that are either in-flight or that are sent on - // a source chain that does not yet have 2 Transceivers will only have been - // sent from a single transceiver, so they would never be able to get - // redeemed. - // Instead, we leave it up to the owner to manually update the threshold - // after some period of time, ideally once all chains have the new Transceiver - // and transfers that were sent via the old configuration are all complete. - // However if the threshold is 0 (the initial case) we do increment to 1. - if (_threshold.num == 0) { - _threshold.num = 1; - } - - emit TransceiverAdded(transceiver, _getNumTransceiversStorage().enabled, _threshold.num); - - _checkThresholdInvariants(); + emit TransceiverAdded(transceiver, _getNumTransceiversStorage().enabled); } /// @inheritdoc IManagerBase @@ -383,34 +359,13 @@ abstract contract ManagerBase is address transceiver ) external onlyOwner { _removeTransceiver(transceiver); - - _Threshold storage _threshold = _getThresholdStorage(); - uint8 numEnabledTransceivers = _getNumTransceiversStorage().enabled; - - if (numEnabledTransceivers < _threshold.num) { - _threshold.num = numEnabledTransceivers; - } - - emit TransceiverRemoved(transceiver, _threshold.num); - - _checkThresholdInvariants(); + emit TransceiverRemoved(transceiver); } /// @inheritdoc IManagerBase - function setThreshold( - uint8 threshold - ) external onlyOwner { - if (threshold == 0) { - revert ZeroThreshold(); - } - - _Threshold storage _threshold = _getThresholdStorage(); - uint8 oldThreshold = _threshold.num; - - _threshold.num = threshold; - _checkThresholdInvariants(); - - emit ThresholdChanged(oldThreshold, threshold); + /// @dev This is here because it is defined in IManagerBase. + function setThreshold(uint16 sourceChainId, uint8 threshold) external onlyOwner { + _setThreshold(sourceChainId, threshold); } // =============== Internal ============================================================== @@ -480,20 +435,4 @@ abstract contract ManagerBase is ); } } - - function _checkThresholdInvariants() internal view { - uint8 threshold = _getThresholdStorage().num; - _NumTransceivers memory numTransceivers = _getNumTransceiversStorage(); - - // invariant: threshold <= enabledTransceivers.length - if (threshold > numTransceivers.enabled) { - revert ThresholdTooHigh(threshold, numTransceivers.enabled); - } - - if (numTransceivers.registered > 0) { - if (threshold == 0) { - revert ZeroThreshold(); - } - } - } } diff --git a/evm/src/NttManager/NttManager.sol b/evm/src/NttManager/NttManager.sol index 7ac021b91..469ca5b15 100644 --- a/evm/src/NttManager/NttManager.sol +++ b/evm/src/NttManager/NttManager.sol @@ -68,7 +68,6 @@ contract NttManager is INttManager, RateLimiter, ManagerBase { function _initialize() internal virtual override { __NttManager_init(); - _checkThresholdInvariants(); _checkTransceiversInvariants(); } @@ -192,7 +191,7 @@ contract NttManager is INttManager, RateLimiter, ManagerBase { // Compute manager message digest and record transceiver attestation. bytes32 nttManagerMessageHash = _recordTransceiverAttestation(sourceChainId, payload); - if (isMessageApproved(nttManagerMessageHash)) { + if (isMessageApproved(sourceChainId, nttManagerMessageHash)) { executeMsg(sourceChainId, sourceNttManagerAddress, payload); } } diff --git a/evm/src/NttManager/TransceiverRegistry.sol b/evm/src/NttManager/TransceiverRegistry.sol index 4e97992bd..bc32753cb 100644 --- a/evm/src/NttManager/TransceiverRegistry.sol +++ b/evm/src/NttManager/TransceiverRegistry.sol @@ -1,7 +1,11 @@ // SPDX-License-Identifier: Apache 2 pragma solidity >=0.8.8 <0.9.0; +import "../libraries/TransceiverHelpers.sol"; + /// @dev TransceiverRegistryBase is a base class shared between TransceiverRegistry and TransceiverRegistryAdmin. +/// It defines all of the state for the transceiver registry. This facilitates using delegate calls to implement +/// the admin functionality in a separate contract (TransceiverRegistryAdmin). abstract contract TransceiverRegistryBase { /// @dev Information about registered transceivers. struct TransceiverInfo { @@ -44,6 +48,8 @@ abstract contract TransceiverRegistryBase { bytes32 internal constant NUM_REGISTERED_TRANSCEIVERS_SLOT = bytes32(uint256(keccak256("ntt.numRegisteredTransceivers")) - 1); + // =============== Storage slot accessor functions ======================================== + function _getTransceiverInfosStorage() internal pure @@ -86,6 +92,73 @@ abstract contract TransceiverRegistryBase { $.slot := slot } } + + // + // =============== Per-chain transceiver storage =========================================== + // + struct _PerChainTransceiverData { + uint64 bitmap; + /// @dev By putting this here, we can implement the admin code in TransceiverRegistryBase rather than ManagerBase. + uint8 threshold; + } + + // =============== Storage slot for per-chain transceivers, send side ====================== + + /// @dev Holds Chain ID => Enabled send side transceiver address[] mapping. + /// mapping(uint16 => address[]). + bytes32 internal constant ENABLED_SEND_TRANSCEIVER_ARRAY_SLOT = + bytes32(uint256(keccak256("registry.sendTransceiverArray")) - 1); + + // =============== Storage slot for per-chain transceivers, receive side ================== + + /// @dev Holds Chain ID => Enabled transceiver receive side bitmap mapping. + /// mapping(uint16 => uint64). + bytes32 internal constant ENABLED_RECV_TRANSCEIVER_BITMAP_SLOT = + bytes32(uint256(keccak256("registry.recvTransceiverBitmap")) - 1); + + // =============== Storage slot for tracking enabled chains =============================== + + /// @dev Holds mapping of array of chains with transceivers enabled for sending. + bytes32 internal constant SEND_ENABLED_CHAINS_SLOT = + bytes32(uint256(keccak256("registry.sendEnabledChains")) - 1); + + /// @dev Holds mapping of array of chains with transceivers enabled for receiving. + bytes32 internal constant RECV_ENABLED_CHAINS_SLOT = + bytes32(uint256(keccak256("registry.recvEnabledChains")) - 1); + + /// @dev Chain ID => Enabled transceiver bitmap mapping. + function _getPerChainSendTransceiverArrayStorage() + internal + pure + returns (mapping(uint16 => address[]) storage $) + { + uint256 slot = uint256(ENABLED_SEND_TRANSCEIVER_ARRAY_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + /// @dev Chain ID => Enabled transceiver bitmap mapping. + function _getPerChainRecvTransceiverDataStorage() + internal + pure + returns (mapping(uint16 => _PerChainTransceiverData) storage $) + { + uint256 slot = uint256(ENABLED_RECV_TRANSCEIVER_BITMAP_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + /// @dev Contains all chains that have transceivers enabled. + function _getChainsEnabledStorage( + bytes32 tag + ) internal pure returns (uint16[] storage $) { + uint256 slot = uint256(tag); + assembly ("memory-safe") { + $.slot := slot + } + } } /// @title TransceiverRegistry @@ -129,11 +202,41 @@ abstract contract TransceiverRegistry is TransceiverRegistryBase { /// @param transceiver The address of the transceiver. error NonRegisteredTransceiver(address transceiver); + /// @notice Error when attempting to use an incorrect chain. + /// @dev Selector: 0x587c94c3. + /// @param chain The id of the incorrect chain. + error InvalidChain(uint16 chain); + /// @notice Error when attempting to enable a transceiver that is already enabled. - /// @dev Selector 0x8d68f84d. + /// @dev Selector: 0x8d68f84d. /// @param transceiver The address of the transceiver. error TransceiverAlreadyEnabled(address transceiver); + /// @notice Error when the transceiver is disabled. + /// @dev Selector: 0xa64030ff. + error TransceiverAlreadyDisabled(address transceiver); + + /// @notice Attempting to remove a transceiver when it is still enabled for receiving on at least one chain. + /// @dev Selector: 0x7481293a. + /// @param chain The first chain on which the transceiver is still registered. + /// @param bitmap The bitmap of enabled transceivers for that chain. + error TransceiverStillEnabledForRecv(uint16 chain, uint64 bitmap); + + /// @notice Attempting to remove a transceiver when it is still enabled for sending on at least one chain. + /// @dev Selector: 0x2bb41527. + /// @param chain The first chain on which the transceiver is still registered. + error TransceiverStillEnabledForSend(uint16 chain); + + /// @notice The threshold for transceiver attestations is too high. + /// @param chainId The chain with the invalid threshold. + /// @param threshold The threshold. + /// @param transceivers The number of transceivers. + error ThresholdTooHigh(uint16 chainId, uint256 threshold, uint256 transceivers); + + /// @notice The number of thresholds should not be zero. + /// @param chainId The chain with the invalid threshold. + error ZeroThreshold(uint16 chainId); + modifier onlyTransceiver() { if (!_getTransceiverInfosStorage()[msg.sender].enabled) { revert CallerNotTransceiver(msg.sender); @@ -164,6 +267,46 @@ abstract contract TransceiverRegistry is TransceiverRegistryBase { _checkDelegateCallRevert(success, returnData); } + function enableSendTransceiverForChain(uint16 chain, address transceiver) public { + (bool success, bytes memory returnData) = _admin.delegatecall( + abi.encodeWithSelector( + TransceiverRegistryAdmin._enableSendTransceiverForChain.selector, chain, transceiver + ) + ); + _checkDelegateCallRevert(success, returnData); + } + + function disableSendTransceiverForChain(uint16 chain, address transceiver) public { + (bool success, bytes memory returnData) = _admin.delegatecall( + abi.encodeWithSelector( + TransceiverRegistryAdmin._disableSendTransceiverForChain.selector, + chain, + transceiver + ) + ); + _checkDelegateCallRevert(success, returnData); + } + + function enableRecvTransceiverForChain(uint16 chain, address transceiver) public { + (bool success, bytes memory returnData) = _admin.delegatecall( + abi.encodeWithSelector( + TransceiverRegistryAdmin._enableRecvTransceiverForChain.selector, chain, transceiver + ) + ); + _checkDelegateCallRevert(success, returnData); + } + + function disableRecvTransceiverForChain(uint16 chain, address transceiver) public { + (bool success, bytes memory returnData) = _admin.delegatecall( + abi.encodeWithSelector( + TransceiverRegistryAdmin._disableRecvTransceiverForChain.selector, + chain, + transceiver + ) + ); + _checkDelegateCallRevert(success, returnData); + } + function _getEnabledTransceiversBitmap() internal view virtual returns (uint64 bitmap) { return _getTransceiverBitmapStorage().bitmap; } @@ -187,6 +330,65 @@ abstract contract TransceiverRegistry is TransceiverRegistryBase { return result; } + /// @notice Returns the enabled send side transceiver addresses for the given chain. + /// @param chain The Wormhole chain ID for the desired transceivers. + /// @return result The enabled send side transceivers for the given chain. + function getEnabledSendTransceiversForChain( + uint16 chain + ) public view returns (address[] memory result) { + if (chain == 0) { + revert InvalidChain(chain); + } + result = _getPerChainSendTransceiverArrayStorage()[chain]; + } + + /// @notice Returns the enabled receive side transceiver bitmap for the given chain. + /// @param chain The Wormhole chain ID for the desired transceivers. + /// @return result The enabled receive side transceiver bitmap for the given chain. + function getEnabledRecvTransceiversBitmapForChain( + uint16 chain + ) public view returns (uint64 result) { + if (chain == 0) { + revert InvalidChain(chain); + } + result = _getPerChainRecvTransceiverDataStorage()[chain].bitmap; + } + + /// @notice Returns whether or not the receive side transceiver is enabled for the given chain. + /// @dev This function is private and should only be called by a function that checks the validity of chain and transceiver. + /// @param chain The Wormhole chain ID. + /// @param index The index of the transceiver. + /// @return true if the transceiver is enabled, false otherwise. + function _isRecvTransceiverEnabledForChain( + uint16 chain, + uint8 index + ) internal view returns (bool) { + uint64 bitmap = _getPerChainRecvTransceiverDataStorage()[chain].bitmap; + return (bitmap & uint64(1 << index)) > 0; + } + + /// @notice Sets the receive threshold for the specified chain. + /// @param chain The Wormhole chain ID. + /// @param threshold The updated threshold value. + function _setThreshold(uint16 chain, uint8 threshold) internal { + (bool success, bytes memory returnData) = _admin.delegatecall( + abi.encodeWithSelector( + TransceiverRegistryAdmin._setThreshold.selector, chain, threshold + ) + ); + _checkDelegateCallRevert(success, returnData); + } + + /// @notice Returns the set of chains for which sending is enabled. + function getChainsEnabledForSending() external pure returns (uint16[] memory) { + return _getChainsEnabledStorage(SEND_ENABLED_CHAINS_SLOT); + } + + /// @notice Returns the set of chains for which receiving is enabled. + function getChainsEnabledForReceiving() external pure returns (uint16[] memory) { + return _getChainsEnabledStorage(RECV_ENABLED_CHAINS_SLOT); + } + // ============== Invariants ============================================= /// @dev Check that the transceiver nttManager is in a valid state. @@ -199,18 +401,6 @@ abstract contract TransceiverRegistry is TransceiverRegistryBase { _checkDelegateCallRevert(success, returnData); } - // @dev Check that the transceiver is in a valid state. - function _checkTransceiverInvariants( - address transceiver - ) private view { - (bool success, bytes memory returnData) = _admin.staticcall( - abi.encodeWithSelector( - TransceiverRegistryAdmin._checkTransceiverInvariants.selector, transceiver - ) - ); - _checkDelegateCallRevert(success, returnData); - } - function _checkDelegateCallRevert(bool success, bytes memory returnData) private pure { // if the function call reverted if (success == false) { @@ -222,7 +412,7 @@ abstract contract TransceiverRegistry is TransceiverRegistryBase { revert(add(32, returnData), returndata_size) } } else { - revert("_removeTransceiver reverted"); + revert("delegate call reverted"); } } } @@ -231,6 +421,41 @@ abstract contract TransceiverRegistry is TransceiverRegistryBase { /// @dev TransceiverRegistryAdmin is a helper contract to TransceiverRegistry. /// It implements admin functionality and is called via `delegatecall`. contract TransceiverRegistryAdmin is TransceiverRegistryBase { + /// @notice Emitted when a send side transceiver is enabled for a chain. + /// @dev Topic0 + /// 0x86c081420b3eb6721acf690f71cab5dea27b08f0be33f4319cdbb4a5733e7ac6. + /// @param chain The Wormhole chain ID on which this transceiver is enabled. + /// @param transceiver The address of the transceiver. + event SendTransceiverEnabledForChain(uint16 chain, address transceiver); + + /// @notice Emitted when a receive side transceiver is enabled for a chain. + /// @dev Topic0 + /// 0x6ceee5880439d670aa17a1428ce3f83fb3da492eb152aecef53eca06f0388bda. + /// @param chain The Wormhole chain ID on which this transceiver is enabled. + /// @param transceiver The address of the transceiver. + event RecvTransceiverEnabledForChain(uint16 chain, address transceiver, uint8 threshold); + + /// @notice Emitted when a send side transceiver is disabled. + /// @dev Topic0 + /// 0x6ceee5880439d670aa17a1428ce3f83fb3da492eb152aecef53eca06f0388bda. + /// @param chain The Wormhole chain ID on which this transceiver is disabled. + /// @param transceiver The address of the transceiver. + event SendTransceiverDisabledForChain(uint16 chain, address transceiver); + + /// @notice Emitted when a receive side transceiver is disabled. + /// @dev Topic0 + /// 0xdcad454c5c2805c34d9de195bc0f494aa9b7a73a7e1a3896d40004094dd9a499. + /// @param chain The Wormhole chain ID on which this transceiver is disabled. + /// @param transceiver The address of the transceiver. + event RecvTransceiverDisabledForChain(uint16 chain, address transceiver); + + /// @notice Emmitted when the threshold required transceivers is changed. + /// @dev Topic0 + /// 0x2a855b929b9a53c6fb5b5ed248b27e502b709c088e036a5aa17620c8fc5085a9. + /// @param oldThreshold The old threshold. + /// @param threshold The new threshold. + event ThresholdChanged(uint8 oldThreshold, uint8 threshold); + function _setTransceiver( address transceiver ) public returns (uint8 index) { @@ -295,6 +520,9 @@ contract TransceiverRegistryAdmin is TransceiverRegistryBase { revert TransceiverRegistry.DisabledTransceiver(transceiver); } + // Reverts if the receiver is enabled for sending or receiving on any chain. + _checkTransceiverNotEnabled(transceiver, transceiverInfos[transceiver].index); + transceiverInfos[transceiver].enabled = false; _getNumTransceiversStorage().enabled--; @@ -323,6 +551,249 @@ contract TransceiverRegistryAdmin is TransceiverRegistryBase { _checkTransceiverInvariants(transceiver); } + /// @dev Reverts if the transceiver is enabled on any chain. + /// @param transceiver The transceiver being removed. + /// @param index The index of the transceiver. + function _checkTransceiverNotEnabled(address transceiver, uint8 index) private view { + // Check the send side. + uint16[] storage chains = _getChainsEnabledStorage(SEND_ENABLED_CHAINS_SLOT); + uint256 numChains = chains.length; + for (uint256 chainIdx = 0; (chainIdx < numChains);) { + address[] storage transceivers = + _getPerChainSendTransceiverArrayStorage()[chains[chainIdx]]; + uint256 numTransceivers = transceivers.length; + for (uint256 transceiverIdx = 0; (transceiverIdx < numTransceivers);) { + if (transceivers[transceiverIdx] == transceiver) { + revert TransceiverRegistry.TransceiverStillEnabledForSend(chains[chainIdx]); + } + unchecked { + ++transceiverIdx; + } + } + + unchecked { + ++chainIdx; + } + } + + // Check the receive side. + chains = _getChainsEnabledStorage(RECV_ENABLED_CHAINS_SLOT); + numChains = chains.length; + for (uint256 idx = 0; (idx < numChains);) { + uint64 bitmap = _getPerChainRecvTransceiverDataStorage()[chains[idx]].bitmap; + if (bitmap & uint64(1 << index) != 0) { + revert TransceiverRegistry.TransceiverStillEnabledForRecv(chains[idx], bitmap); + } + unchecked { + ++idx; + } + } + } + + /// @dev This just enables the send side transceiver for a chain. It does not register it. + /// @param chain The Wormhole chain ID. + /// @param transceiver The transceiver address. + function _enableSendTransceiverForChain( + uint16 chain, + address transceiver + ) public onlyRegisteredTransceiver(chain, transceiver) { + if (_isSendTransceiverEnabledForChain(chain, transceiver)) { + revert TransceiverRegistry.TransceiverAlreadyEnabled(transceiver); + } + address[] storage sendTransceiverArray = _getPerChainSendTransceiverArrayStorage()[chain]; + if (sendTransceiverArray.length == 0) { + _addEnabledChain(SEND_ENABLED_CHAINS_SLOT, chain); + } + sendTransceiverArray.push(transceiver); + emit SendTransceiverEnabledForChain(chain, transceiver); + } + + /// @notice Disables a send side transceiver for a chain. + /// @param chain The chain ID. + /// @param transceiver The transceiver address. + function _disableSendTransceiverForChain( + uint16 chain, + address transceiver + ) public onlyRegisteredTransceiver(chain, transceiver) { + mapping(uint16 => address[]) storage enabledSendTransceivers = + _getPerChainSendTransceiverArrayStorage(); + address[] storage transceivers = enabledSendTransceivers[chain]; + + // Get the index of the disabled transceiver in the enabled transceivers array + // and replace it with the last element in the array. + uint256 len = transceivers.length; + bool found = false; + for (uint256 i = 0; i < len;) { + if (transceivers[i] == transceiver) { + // Swap the last element with the element to be removed + transceivers[i] = transceivers[len - 1]; + // Remove the last element + transceivers.pop(); + found = true; + if (transceivers.length == 0) { + _removeEnabledChain(SEND_ENABLED_CHAINS_SLOT, chain); + } + break; + } + unchecked { + ++i; + } + } + if (!found) { + revert TransceiverRegistry.TransceiverAlreadyDisabled(transceiver); + } + + emit SendTransceiverDisabledForChain(chain, transceiver); + } + + /// @dev This just enables the receive side transceiver for a chain. It does not register it. + /// @param chain The Wormhole chain ID. + /// @param transceiver The transceiver address. + function _enableRecvTransceiverForChain( + uint16 chain, + address transceiver + ) public onlyRegisteredTransceiver(chain, transceiver) { + if (_isRecvTransceiverEnabledForChain(chain, transceiver)) { + revert TransceiverRegistry.TransceiverAlreadyEnabled(transceiver); + } + uint8 index = _getTransceiverInfosStorage()[transceiver].index; + _PerChainTransceiverData storage _bitmapEntry = + _getPerChainRecvTransceiverDataStorage()[chain]; + if (_bitmapEntry.bitmap == 0) { + _addEnabledChain(RECV_ENABLED_CHAINS_SLOT, chain); + _bitmapEntry.threshold = 1; + } + _bitmapEntry.bitmap |= uint64(1 << index); + emit RecvTransceiverEnabledForChain(chain, transceiver, _bitmapEntry.threshold); + } + + /// @notice Disables a receive side transceiver for a chain. + /// @dev Will revert under the following conditions: + /// - The transceiver is the zero address. + /// - The transceiver is not registered. + /// @param chain The Wormhole chain ID. + /// @param transceiver The transceiver address. + function _disableRecvTransceiverForChain( + uint16 chain, + address transceiver + ) public onlyRegisteredTransceiver(chain, transceiver) { + mapping(address => TransceiverInfo) storage transceiverInfos = _getTransceiverInfosStorage(); + _PerChainTransceiverData storage _data = _getPerChainRecvTransceiverDataStorage()[chain]; + + uint64 updatedEnabledTransceiverBitmap = + _data.bitmap & uint64(~(1 << transceiverInfos[transceiver].index)); + // ensure that this actually changed the bitmap + if (updatedEnabledTransceiverBitmap >= _data.bitmap) { + revert TransceiverRegistry.TransceiverAlreadyDisabled(transceiver); + } + _data.bitmap = updatedEnabledTransceiverBitmap; + if (_data.bitmap == 0) { + _removeEnabledChain(RECV_ENABLED_CHAINS_SLOT, chain); + } + + emit RecvTransceiverDisabledForChain(chain, transceiver); + + uint8 numEnabled = countSetBits(_data.bitmap); + if (numEnabled < _data.threshold) { + emit ThresholdChanged(_data.threshold, numEnabled); + _data.threshold = numEnabled; + } + } + + /// @notice Returns whether or not the send side transceiver is enabled for the given chain. + /// @dev This function is private and should only be called by a function that checks the validity of chain and transceiver. + /// @param chain The Wormhole chain ID. + /// @param transceiver The transceiver address. + /// @return true if the transceiver is enabled, false otherwise. + function _isSendTransceiverEnabledForChain( + uint16 chain, + address transceiver + ) private view returns (bool) { + address[] storage transceivers = _getPerChainSendTransceiverArrayStorage()[chain]; + uint256 length = transceivers.length; + for (uint256 i = 0; i < length;) { + if (transceivers[i] == transceiver) { + return true; + } + unchecked { + ++i; + } + } + return false; + } + + /// @notice Returns whether or not the receive side transceiver is enabled for the given chain. + /// @dev This function is private and should only be called by a function that checks the validity of chain and transceiver. + /// @param chain The Wormhole chain ID. + /// @param transceiver The transceiver address. + /// @return true if the transceiver is enabled, false otherwise. + function _isRecvTransceiverEnabledForChain( + uint16 chain, + address transceiver + ) private view returns (bool) { + uint64 bitmap = _getPerChainRecvTransceiverDataStorage()[chain].bitmap; + uint8 index = _getTransceiverInfosStorage()[transceiver].index; + return (bitmap & uint64(1 << index)) > 0; + } + + /// @dev It is assumed that the chain is not already in the list. We can get away with this because the function is internal. + /// @dev Although this is a one line function, we have it for two reasons: (1) symmetry with remove, (2) simplifies testing. + /// The assumption is that the compiler will inline it anyway. + function _addEnabledChain(bytes32 tag, uint16 chain) internal { + _getChainsEnabledStorage(tag).push(chain); + } + + /// @dev It's not an error if the chain is not in the list. + function _removeEnabledChain(bytes32 tag, uint16 chain) internal { + uint16[] storage chains = _getChainsEnabledStorage(tag); + uint256 len = chains.length; + for (uint256 idx = 0; (idx < len);) { + if (chains[idx] == chain) { + chains[idx] = chains[len - 1]; + chains.pop(); + return; + } + unchecked { + ++idx; + } + } + } + + /// @notice Sets the receive threshold for the specified chain. + /// @param chain The Wormhole chain ID. + /// @param threshold The updated threshold value. + function _setThreshold(uint16 chain, uint8 threshold) public { + if (threshold == 0) { + revert TransceiverRegistry.ZeroThreshold(chain); + } + + _PerChainTransceiverData storage _data = _getPerChainRecvTransceiverDataStorage()[chain]; + uint8 oldThreshold = _data.threshold; + _data.threshold = threshold; + _checkThresholdInvariant(chain); + emit ThresholdChanged(oldThreshold, threshold); + } + + // =============== Modifiers ====================================================== + + /// @notice This modifier will revert if the transceiver is an invalid address, not registered, or the chain is invalid. + /// @param chain The Wormhole chain ID. + /// @param transceiver The transceiver address. + modifier onlyRegisteredTransceiver(uint16 chain, address transceiver) { + if (transceiver == address(0)) { + revert TransceiverRegistry.InvalidTransceiverZeroAddress(); + } + + if (chain == 0) { + revert TransceiverRegistry.InvalidChain(chain); + } + + if (!_getTransceiverInfosStorage()[transceiver].registered) { + revert TransceiverRegistry.NonRegisteredTransceiver(transceiver); + } + _; + } + /// @dev Check that the transceiver nttManager is in a valid state. /// Checking these invariants is somewhat costly, but we only need to do it /// when modifying the transceivers, which happens infrequently. @@ -346,6 +817,8 @@ contract TransceiverRegistryAdmin is TransceiverRegistryBase { // invariant: numRegisteredTransceivers <= MAX_TRANSCEIVERS assert(_numTransceivers.registered <= MAX_TRANSCEIVERS); + + _checkPerChainTransceiversInvariants(); } // @dev Check that the transceiver is in a valid state. @@ -359,7 +832,7 @@ contract TransceiverRegistryAdmin is TransceiverRegistryBase { TransceiverInfo memory transceiverInfo = transceiverInfos[transceiver]; - // if an transceiver is not registered, it should not be enabled + // if a transceiver is not registered, it should not be enabled assert( transceiverInfo.registered || (!transceiverInfo.enabled && transceiverInfo.index == 0) ); @@ -385,5 +858,50 @@ contract TransceiverRegistryAdmin is TransceiverRegistryBase { assert(transceiverInEnabledTransceivers == transceiverEnabled); assert(transceiverInfo.index < _numTransceivers.registered); + + _checkThresholdInvariants(); + } + + function _checkPerChainTransceiversInvariants() internal pure { + // Send side + uint16[] memory chains = _getChainsEnabledStorage(SEND_ENABLED_CHAINS_SLOT); + uint256 len = chains.length; + for (uint256 idx = 0; (idx < len);) { + // Make sure there is an enabled transceiver for this chain. + unchecked { + ++idx; + } + } + } + + function _checkThresholdInvariants() public view { + uint16[] storage chains = _getChainsEnabledStorage(RECV_ENABLED_CHAINS_SLOT); + uint256 len = chains.length; + for (uint256 idx = 0; (idx < len);) { + _checkThresholdInvariant(chains[idx]); + unchecked { + ++idx; + } + } + } + + function _checkThresholdInvariant( + uint16 chain + ) internal view { + // mapping(address => TransceiverInfo) storage transceiverInfos = _getTransceiverInfosStorage(); + _PerChainTransceiverData storage _data = _getPerChainRecvTransceiverDataStorage()[chain]; + + uint8 numEnabled = countSetBits(_data.bitmap); + + // invariant: threshold <= enabledTransceivers.length + if (_data.threshold > numEnabled) { + revert TransceiverRegistry.ThresholdTooHigh(chain, _data.threshold, numEnabled); + } + + if (numEnabled > 0) { + if (_data.threshold == 0) { + revert TransceiverRegistry.ZeroThreshold(chain); + } + } } } diff --git a/evm/src/interfaces/IManagerBase.sol b/evm/src/interfaces/IManagerBase.sol index c4d1d565f..ca9d4ed1a 100644 --- a/evm/src/interfaces/IManagerBase.sol +++ b/evm/src/interfaces/IManagerBase.sol @@ -39,27 +39,18 @@ interface IManagerBase { /// @param index The index of the transceiver in the bitmap. event MessageAttestedTo(bytes32 digest, address transceiver, uint8 index); - /// @notice Emmitted when the threshold required transceivers is changed. - /// @dev Topic0 - /// 0x2a855b929b9a53c6fb5b5ed248b27e502b709c088e036a5aa17620c8fc5085a9. - /// @param oldThreshold The old threshold. - /// @param threshold The new threshold. - event ThresholdChanged(uint8 oldThreshold, uint8 threshold); - /// @notice Emitted when an transceiver is removed from the nttManager. /// @dev Topic0 - /// 0xf05962b5774c658e85ed80c91a75af9d66d2af2253dda480f90bce78aff5eda5. + /// 0x2fb241a51a63da05063ac6be1f963395b281e455e8085bd246a7e8502b8950d5. /// @param transceiver The address of the transceiver. /// @param transceiversNum The current number of transceivers. - /// @param threshold The current threshold of transceivers. - event TransceiverAdded(address transceiver, uint256 transceiversNum, uint8 threshold); + event TransceiverAdded(address transceiver, uint256 transceiversNum); /// @notice Emitted when an transceiver is removed from the nttManager. /// @dev Topic0 - /// 0x697a3853515b88013ad432f29f53d406debc9509ed6d9313dcfe115250fcd18f. + /// 0x61c80f4ebc728c94f4f766219fd399e4622bcee9d724b0910f59cf1cd544aac7. /// @param transceiver The address of the transceiver. - /// @param threshold The current threshold of transceivers. - event TransceiverRemoved(address transceiver, uint8 threshold); + event TransceiverRemoved(address transceiver); /// @notice payment for a transfer is too low. /// @param requiredPayment The required payment. @@ -71,16 +62,8 @@ interface IManagerBase { /// @param refundAmount The refund amount. error RefundFailed(uint256 refundAmount); - /// @notice The number of thresholds should not be zero. - error ZeroThreshold(); - error RetrievedIncorrectRegisteredTransceivers(uint256 retrieved, uint256 registered); - /// @notice The threshold for transceiver attestations is too high. - /// @param threshold The threshold. - /// @param transceivers The number of transceivers. - error ThresholdTooHigh(uint256 threshold, uint256 transceivers); - /// @notice Error when the tranceiver already attested to the message. /// To ensure the client does not continue to initiate calls to the attestationReceived function. /// @dev Selector 0x2113894. @@ -120,11 +103,10 @@ interface IManagerBase { /// @notice Sets the threshold for the number of attestations required for a message /// to be considered valid. + /// @param sourceChainId The chain ID to which the threshold applies. /// @param threshold The new threshold (number of attestations). /// @dev This method can only be executed by the `owner`. - function setThreshold( - uint8 threshold - ) external; + function setThreshold(uint16 sourceChainId, uint8 threshold) external; /// @notice Sets the transceiver for the given chain. /// @param transceiver The address of the transceiver. @@ -142,11 +124,10 @@ interface IManagerBase { /// @notice Checks if a message has been approved. The message should have at least /// the minimum threshold of attestations from distinct endpoints. + /// @param sourceChainId The chain ID from which the message was received. /// @param digest The digest of the message. /// @return - Boolean indicating if message has been approved. - function isMessageApproved( - bytes32 digest - ) external view returns (bool); + function isMessageApproved(uint16 sourceChainId, bytes32 digest) external view returns (bool); /// @notice Checks if a message has been executed. /// @param digest The digest of the message. @@ -175,7 +156,10 @@ interface IManagerBase { /// @notice Returns the number of Transceivers that must attest to a msgId for /// it to be considered valid and acted upon. - function getThreshold() external view returns (uint8); + /// @param chainId The chain for which the threshold applies. + function getThreshold( + uint16 chainId + ) external view returns (uint8); /// @notice Returns a boolean indicating if the transceiver has attested to the message. /// @param digest The digest of the message. diff --git a/evm/test/IntegrationAdditionalTransfer.t.sol b/evm/test/IntegrationAdditionalTransfer.t.sol index b281c4abb..b6f50cdf7 100755 --- a/evm/test/IntegrationAdditionalTransfer.t.sol +++ b/evm/test/IntegrationAdditionalTransfer.t.sol @@ -16,6 +16,7 @@ import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; import "../src/interfaces/IWormholeTransceiver.sol"; import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; +import "./libraries/TransceiverHelpers.sol"; import "./mocks/MockNttManagerAdditionalPayload.sol"; import "./mocks/MockTransceivers.sol"; @@ -94,7 +95,9 @@ contract TestAdditionalPayload is Test { // Actually initialize properly now wormholeTransceiverChain1.initialize(); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1) + ); // nttManagerChain1.setOutboundLimit(type(uint64).max); // nttManagerChain1.setInboundLimit(type(uint64).max, chainId2); @@ -123,7 +126,9 @@ contract TestAdditionalPayload is Test { ); wormholeTransceiverChain2.initialize(); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2) + ); // nttManagerChain2.setOutboundLimit(type(uint64).max); // nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); @@ -143,11 +148,14 @@ contract TestAdditionalPayload is Test { chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1)))) ); - require(nttManagerChain1.getThreshold() != 0, "Threshold is zero with active transceivers"); + require( + nttManagerChain1.getThreshold(chainId2) != 0, + "Threshold is zero with active transceivers" + ); // Actually set it - nttManagerChain1.setThreshold(1); - nttManagerChain2.setThreshold(1); + nttManagerChain1.setThreshold(chainId2, 1); + nttManagerChain2.setThreshold(chainId1, 1); INttManager.NttManagerPeer memory peer = nttManagerChain1.getPeer(chainId2); require(9 == peer.tokenDecimals, "Peer has the wrong number of token decimals"); diff --git a/evm/test/IntegrationManual.t.sol b/evm/test/IntegrationManual.t.sol index 56d762996..2adaefeca 100644 --- a/evm/test/IntegrationManual.t.sol +++ b/evm/test/IntegrationManual.t.sol @@ -7,6 +7,7 @@ import {WormholeRelayerBasicTest} from "wormhole-solidity-sdk/testing/WormholeRe import "./libraries/IntegrationHelpers.sol"; import "wormhole-solidity-sdk/testing/helpers/WormholeSimulator.sol"; import "../src/NttManager/NttManager.sol"; +import "./libraries/TransceiverHelpers.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -67,7 +68,9 @@ contract TestRelayerEndToEndManual is IntegrationHelpers, IRateLimiterEvents { ); wormholeTransceiverChain1.initialize(); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1) + ); nttManagerChain1.setOutboundLimit(type(uint64).max); nttManagerChain1.setInboundLimit(type(uint64).max, chainId2); @@ -94,7 +97,9 @@ contract TestRelayerEndToEndManual is IntegrationHelpers, IRateLimiterEvents { ); wormholeTransceiverChain2.initialize(); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2) + ); nttManagerChain2.setOutboundLimit(type(uint64).max); nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); diff --git a/evm/test/IntegrationRelayer.t.sol b/evm/test/IntegrationRelayer.t.sol index b1b38b662..cdca986e9 100755 --- a/evm/test/IntegrationRelayer.t.sol +++ b/evm/test/IntegrationRelayer.t.sol @@ -93,11 +93,15 @@ contract TestEndToEndRelayer is IntegrationHelpers, IRateLimiterEvents, Wormhole ); wormholeTransceiverChain1Other.initialize(); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1Other)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1) + ); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1Other) + ); nttManagerChain1.setOutboundLimit(type(uint64).max); nttManagerChain1.setInboundLimit(type(uint64).max, chainId2); - nttManagerChain1.setThreshold(1); + nttManagerChain1.setThreshold(chainId2, 1); } // Setup the chain to relay to of the network @@ -141,12 +145,16 @@ contract TestEndToEndRelayer is IntegrationHelpers, IRateLimiterEvents, Wormhole ); wormholeTransceiverChain2Other.initialize(); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2)); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2Other)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2) + ); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2Other) + ); nttManagerChain2.setOutboundLimit(type(uint64).max); nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); - nttManagerChain2.setThreshold(1); + nttManagerChain2.setThreshold(chainId1, 1); } function test_chainToChainReverts() public { diff --git a/evm/test/IntegrationStandalone.t.sol b/evm/test/IntegrationStandalone.t.sol index ecaa2b310..4b88132e0 100755 --- a/evm/test/IntegrationStandalone.t.sol +++ b/evm/test/IntegrationStandalone.t.sol @@ -16,6 +16,7 @@ import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; import "../src/interfaces/IWormholeTransceiver.sol"; import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; +import "./libraries/TransceiverHelpers.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -93,7 +94,9 @@ contract TestEndToEndBase is Test, IRateLimiterEvents { // Actually initialize properly now wormholeTransceiverChain1.initialize(); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1) + ); nttManagerChain1.setOutboundLimit(type(uint64).max); nttManagerChain1.setInboundLimit(type(uint64).max, chainId2); @@ -121,7 +124,9 @@ contract TestEndToEndBase is Test, IRateLimiterEvents { ); wormholeTransceiverChain2.initialize(); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2) + ); nttManagerChain2.setOutboundLimit(type(uint64).max); nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); @@ -141,11 +146,14 @@ contract TestEndToEndBase is Test, IRateLimiterEvents { chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1)))) ); - require(nttManagerChain1.getThreshold() != 0, "Threshold is zero with active transceivers"); + require( + nttManagerChain1.getThreshold(chainId2) != 0, + "Threshold is zero with active transceivers" + ); // Actually set it - nttManagerChain1.setThreshold(1); - nttManagerChain2.setThreshold(1); + nttManagerChain1.setThreshold(chainId2, 1); + nttManagerChain2.setThreshold(chainId1, 1); } function test_chainToChainBase() public { @@ -536,12 +544,16 @@ contract TestEndToEndBase is Test, IRateLimiterEvents { wormholeTransceiverChain2_2.setWormholePeer( chainId1, bytes32(uint256(uint160((address(wormholeTransceiverChain1_2))))) ); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2_2)); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1_2)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2_2) + ); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1_2) + ); // Change the threshold from the setUp functions 1 to 2. - nttManagerChain1.setThreshold(2); - nttManagerChain2.setThreshold(2); + nttManagerChain1.setThreshold(chainId2, 2); + nttManagerChain2.setThreshold(chainId1, 2); // Setting up the transfer DummyToken token1 = DummyToken(nttManagerChain1.token()); diff --git a/evm/test/IntegrationWithoutRateLimiting.t.sol b/evm/test/IntegrationWithoutRateLimiting.t.sol index 8d7e5d0fe..a9d7e41ba 100755 --- a/evm/test/IntegrationWithoutRateLimiting.t.sol +++ b/evm/test/IntegrationWithoutRateLimiting.t.sol @@ -16,6 +16,7 @@ import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; import "../src/interfaces/IWormholeTransceiver.sol"; import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; +import "./libraries/TransceiverHelpers.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -94,7 +95,9 @@ contract TestEndToEndNoRateLimiting is Test { // Actually initialize properly now wormholeTransceiverChain1.initialize(); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1) + ); // nttManagerChain1.setOutboundLimit(type(uint64).max); // nttManagerChain1.setInboundLimit(type(uint64).max, chainId2); @@ -123,7 +126,9 @@ contract TestEndToEndNoRateLimiting is Test { ); wormholeTransceiverChain2.initialize(); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2) + ); // nttManagerChain2.setOutboundLimit(type(uint64).max); // nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); @@ -143,11 +148,14 @@ contract TestEndToEndNoRateLimiting is Test { chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1)))) ); - require(nttManagerChain1.getThreshold() != 0, "Threshold is zero with active transceivers"); + require( + nttManagerChain1.getThreshold(chainId2) != 0, + "Threshold is zero with active transceivers" + ); // Actually set it - nttManagerChain1.setThreshold(1); - nttManagerChain2.setThreshold(1); + nttManagerChain1.setThreshold(chainId2, 1); + nttManagerChain2.setThreshold(chainId1, 1); INttManager.NttManagerPeer memory peer = nttManagerChain1.getPeer(chainId2); require(9 == peer.tokenDecimals, "Peer has the wrong number of token decimals"); @@ -524,12 +532,16 @@ contract TestEndToEndNoRateLimiting is Test { wormholeTransceiverChain2_2.setWormholePeer( chainId1, bytes32(uint256(uint160((address(wormholeTransceiverChain1_2))))) ); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2_2)); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1_2)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2_2) + ); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1_2) + ); // Change the threshold from the setUp functions 1 to 2. - nttManagerChain1.setThreshold(2); - nttManagerChain2.setThreshold(2); + nttManagerChain1.setThreshold(chainId2, 2); + nttManagerChain2.setThreshold(chainId1, 2); // Setting up the transfer DummyToken token1 = DummyToken(nttManagerChain1.token()); diff --git a/evm/test/NttManager.t.sol b/evm/test/NttManager.t.sol index 171940e0d..ab77b48a1 100644 --- a/evm/test/NttManager.t.sol +++ b/evm/test/NttManager.t.sol @@ -69,7 +69,9 @@ contract TestNttManager is Test, IRateLimiterEvents { nttManagerOther.initialize(); dummyTransceiver = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(dummyTransceiver)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManager, chainId2, address(dummyTransceiver) + ); } // === pure unit tests @@ -106,7 +108,9 @@ contract TestNttManager is Test, IRateLimiterEvents { nttManagerZeroRateLimiter.initialize(); DummyTransceiver e = new DummyTransceiver(address(nttManagerZeroRateLimiter)); - nttManagerZeroRateLimiter.setTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerZeroRateLimiter, chainId2, address(e) + ); address user_A = address(0x123); address user_B = address(0x456); @@ -139,8 +143,10 @@ contract TestNttManager is Test, IRateLimiterEvents { assertEq(s3, 2); // Test incoming transfer completes successfully with rate limit disabled - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerZeroRateLimiter); - nttManagerZeroRateLimiter.setThreshold(2); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerZeroRateLimiter + ); + nttManagerZeroRateLimiter.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 2); // register nttManager peer bytes32 peer = toWormholeFormat(address(nttManager)); @@ -301,25 +307,42 @@ contract TestNttManager is Test, IRateLimiterEvents { function test_cantEnableTransceiverTwice() public { DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); vm.expectRevert( abi.encodeWithSelector( TransceiverRegistry.TransceiverAlreadyEnabled.selector, address(e) ) ); - nttManager.setTransceiver(address(e)); + nttManager.enableSendTransceiverForChain(chainId2, address(e)); } function test_disableReenableTransceiver() public { DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); - nttManager.removeTransceiver(address(e)); - nttManager.setTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); + nttManager.disableSendTransceiverForChain(chainId2, address(e)); + nttManager.disableRecvTransceiverForChain(chainId2, address(e)); + nttManager.enableSendTransceiverForChain(chainId2, address(e)); + nttManager.enableRecvTransceiverForChain(chainId2, address(e)); } - function test_disableAllTransceiversFails() public { - vm.expectRevert(abi.encodeWithSelector(IManagerBase.ZeroThreshold.selector)); + function test_cantRemoveTransceiverWhileEnabled() public { + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.TransceiverStillEnabledForSend.selector, chainId2 + ) + ); + nttManager.removeTransceiver(address(dummyTransceiver)); + + nttManager.disableSendTransceiverForChain(chainId2, address(dummyTransceiver)); + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.TransceiverStillEnabledForRecv.selector, chainId2, 1 + ) + ); + nttManager.removeTransceiver(address(dummyTransceiver)); + + nttManager.disableRecvTransceiverForChain(chainId2, address(dummyTransceiver)); nttManager.removeTransceiver(address(dummyTransceiver)); } @@ -327,8 +350,8 @@ contract TestNttManager is Test, IRateLimiterEvents { DummyTransceiver e1 = new DummyTransceiver(address(nttManager)); DummyTransceiver e2 = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e1)); - nttManager.setTransceiver(address(e2)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e1)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e2)); } function test_transceiverIncompatibleNttManager() public { @@ -394,13 +417,13 @@ contract TestNttManager is Test, IRateLimiterEvents { // Let's register a transceiver and then disable it. We now have 2 registered managers // since we register 1 in the setup DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); - nttManager.removeTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); + TransceiverHelpersLib.disableAndRemoveTransceiver(nttManager, chainId2, address(e)); // We should be able to register 64 transceivers total for (uint256 i = 0; i < 62; ++i) { DummyTransceiver d = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(d)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(d)); } // Registering a new transceiver should fail as we've hit the cap @@ -408,7 +431,7 @@ contract TestNttManager is Test, IRateLimiterEvents { vm.expectRevert(TransceiverRegistry.TooManyTransceivers.selector); nttManager.setTransceiver(address(c)); - // We should be able to renable an already registered transceiver at the cap + // We should be able to register an already registered transceiver at the cap nttManager.setTransceiver(address(e)); } @@ -416,8 +439,10 @@ contract TestNttManager is Test, IRateLimiterEvents { // Let's register a transceiver and then disable the original transceiver. We now have 2 registered transceivers // since we register 1 in the setup DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); - nttManager.removeTransceiver(address(dummyTransceiver)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); + TransceiverHelpersLib.disableAndRemoveTransceiver( + nttManager, chainId2, address(dummyTransceiver) + ); address user_A = address(0x123); address user_B = address(0x456); @@ -527,27 +552,31 @@ contract TestNttManager is Test, IRateLimiterEvents { function test_cantSetThresholdTooHigh() public { // 1 transceiver set, so can't set threshold to 2 - vm.expectRevert(abi.encodeWithSelector(IManagerBase.ThresholdTooHigh.selector, 2, 1)); - nttManager.setThreshold(2); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.ThresholdTooHigh.selector, chainId2, 2, 1) + ); + nttManager.setThreshold(chainId2, 2); } function test_canSetThreshold() public { DummyTransceiver e1 = new DummyTransceiver(address(nttManager)); DummyTransceiver e2 = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e1)); - nttManager.setTransceiver(address(e2)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e1)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e2)); - nttManager.setThreshold(1); - nttManager.setThreshold(2); - nttManager.setThreshold(1); + nttManager.setThreshold(chainId2, 1); + nttManager.setThreshold(chainId2, 2); + nttManager.setThreshold(chainId2, 1); } function test_cantSetThresholdToZero() public { DummyTransceiver e = new DummyTransceiver(address(nttManager)); nttManager.setTransceiver(address(e)); - vm.expectRevert(abi.encodeWithSelector(IManagerBase.ZeroThreshold.selector)); - nttManager.setThreshold(0); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.ZeroThreshold.selector, chainId2) + ); + nttManager.setThreshold(chainId2, 0); } function test_onlyOwnerCanSetThreshold() public { @@ -557,7 +586,7 @@ contract TestNttManager is Test, IRateLimiterEvents { vm.expectRevert( abi.encodeWithSelector(OwnableUpgradeable.OwnableUnauthorizedAccount.selector, notOwner) ); - nttManager.setThreshold(1); + nttManager.setThreshold(chainId2, 1); } // == threshold @@ -580,8 +609,12 @@ contract TestNttManager is Test, IRateLimiterEvents { // === attestation function test_onlyEnabledTransceiversCanAttest() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.removeTransceiver(address(e1)); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); + nttManagerOther.disableRecvTransceiverForChain( + TransceiverHelpersLib.SENDING_CHAIN_ID, address(e1) + ); bytes32 peer = toWormholeFormat(address(nttManager)); nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9, type(uint64).max); @@ -597,8 +630,10 @@ contract TestNttManager is Test, IRateLimiterEvents { } function test_onlyPeerNttManagerCanAttest() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.setThreshold(2); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); + nttManagerOther.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 2); bytes32 peer = toWormholeFormat(address(nttManager)); @@ -618,8 +653,10 @@ contract TestNttManager is Test, IRateLimiterEvents { } function test_attestSimple() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.setThreshold(2); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); + nttManagerOther.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 2); // register nttManager peer bytes32 peer = toWormholeFormat(address(nttManager)); @@ -641,8 +678,10 @@ contract TestNttManager is Test, IRateLimiterEvents { } function test_attestTwice() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.setThreshold(2); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); + nttManagerOther.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 2); // register nttManager peer bytes32 peer = toWormholeFormat(address(nttManager)); @@ -670,8 +709,9 @@ contract TestNttManager is Test, IRateLimiterEvents { } function test_attestDisabled() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.setThreshold(2); + DummyTransceiver e1 = TransceiverHelpersLib.setup_one_transceiver( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); bytes32 peer = toWormholeFormat(address(nttManager)); nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9, type(uint64).max); @@ -691,7 +731,9 @@ contract TestNttManager is Test, IRateLimiterEvents { transceivers ); - nttManagerOther.removeTransceiver(address(e1)); + TransceiverHelpersLib.disableAndRemoveTransceiver( + nttManagerOther, TransceiverHelpersLib.SENDING_CHAIN_ID, address(e1) + ); bytes32 hash = TransceiverStructs.nttManagerMessageDigest(TransceiverHelpersLib.SENDING_CHAIN_ID, m); @@ -796,8 +838,9 @@ contract TestNttManager is Test, IRateLimiterEvents { function test_attestationQuorum() public { address user_B = address(0x456); - (DummyTransceiver e1, DummyTransceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManagerOther); + (DummyTransceiver e1, DummyTransceiver e2) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); TrimmedAmount transferAmount = packTrimmedAmount(50, 8); @@ -843,6 +886,13 @@ contract TestNttManager is Test, IRateLimiterEvents { } function test_transfersOnForkedChains() public { + nttManager.enableSendTransceiverForChain( + TransceiverHelpersLib.SENDING_CHAIN_ID, address(dummyTransceiver) + ); + nttManager.enableRecvTransceiverForChain( + TransceiverHelpersLib.SENDING_CHAIN_ID, address(dummyTransceiver) + ); + uint256 evmChainId = block.chainid; address user_A = address(0x123); @@ -1057,8 +1107,8 @@ contract TestNttManager is Test, IRateLimiterEvents { address user_B = address(0x456); DummyToken token = DummyToken(nttManager.token()); TrimmedAmount transferAmount = packTrimmedAmount(50, 8); - (ITransceiverReceiver e1, ITransceiverReceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManagerOther); + (ITransceiverReceiver e1, ITransceiverReceiver e2) = TransceiverHelpersLib + .setup_transceivers(TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther); // Step 1 (contract is deployed by setUp()) ITransceiverReceiver[] memory transceivers = new ITransceiverReceiver[](2); @@ -1126,7 +1176,9 @@ contract TestNttManager is Test, IRateLimiterEvents { newNttManager.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9, type(uint64).max); { DummyTransceiver e = new DummyTransceiver(address(newNttManager)); - newNttManager.setTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver( + newNttManager, TransceiverHelpersLib.SENDING_CHAIN_ID, address(e) + ); } address user_A = address(0x123); @@ -1147,8 +1199,10 @@ contract TestNttManager is Test, IRateLimiterEvents { vm.stopPrank(); // Check that we can receive a transfer - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(newNttManager); - newNttManager.setThreshold(1); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, newNttManager + ); + newNttManager.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 1); bytes memory transceiverMessage; bytes memory tokenTransferMessage; diff --git a/evm/test/NttManagerNoRateLimiting.t.sol b/evm/test/NttManagerNoRateLimiting.t.sol index 9df4d4d80..73a5c5727 100644 --- a/evm/test/NttManagerNoRateLimiting.t.sol +++ b/evm/test/NttManagerNoRateLimiting.t.sol @@ -70,7 +70,9 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { nttManagerOther.initialize(); dummyTransceiver = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(dummyTransceiver)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManager, chainId2, address(dummyTransceiver) + ); } // === pure unit tests @@ -206,7 +208,7 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { function test_registerTransceiver() public { DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); } function test_onlyOwnerCanModifyTransceivers() public { @@ -229,34 +231,31 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { function test_cantEnableTransceiverTwice() public { DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); vm.expectRevert( abi.encodeWithSelector( TransceiverRegistry.TransceiverAlreadyEnabled.selector, address(e) ) ); - nttManager.setTransceiver(address(e)); + nttManager.enableSendTransceiverForChain(chainId2, address(e)); } function test_disableReenableTransceiver() public { DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); - nttManager.removeTransceiver(address(e)); - nttManager.setTransceiver(address(e)); - } - - function test_disableAllTransceiversFails() public { - vm.expectRevert(abi.encodeWithSelector(IManagerBase.ZeroThreshold.selector)); - nttManager.removeTransceiver(address(dummyTransceiver)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); + nttManager.disableSendTransceiverForChain(chainId2, address(e)); + nttManager.disableRecvTransceiverForChain(chainId2, address(e)); + nttManager.enableSendTransceiverForChain(chainId2, address(e)); + nttManager.enableRecvTransceiverForChain(chainId2, address(e)); } function test_multipleTransceivers() public { DummyTransceiver e1 = new DummyTransceiver(address(nttManager)); DummyTransceiver e2 = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e1)); - nttManager.setTransceiver(address(e2)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e1)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e2)); } function test_transceiverIncompatibleNttManagerNoRateLimiting() public { @@ -325,13 +324,13 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { // Let's register a transceiver and then disable it. We now have 2 registered managers // since we register 1 in the setup DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); - nttManager.removeTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); + TransceiverHelpersLib.disableAndRemoveTransceiver(nttManager, chainId2, address(e)); // We should be able to register 64 transceivers total for (uint256 i = 0; i < 62; ++i) { DummyTransceiver d = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(d)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(d)); } // Registering a new transceiver should fail as we've hit the cap @@ -339,7 +338,7 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { vm.expectRevert(TransceiverRegistry.TooManyTransceivers.selector); nttManager.setTransceiver(address(c)); - // We should be able to renable an already registered transceiver at the cap + // We should be able to register an already registered transceiver at the cap nttManager.setTransceiver(address(e)); } @@ -347,8 +346,10 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { // Let's register a transceiver and then disable the original transceiver. We now have 2 registered transceivers // since we register 1 in the setup DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); - nttManager.removeTransceiver(address(dummyTransceiver)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); + TransceiverHelpersLib.disableAndRemoveTransceiver( + nttManager, chainId2, address(dummyTransceiver) + ); address user_A = address(0x123); address user_B = address(0x456); @@ -389,27 +390,31 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { function test_cantSetThresholdTooHigh() public { // 1 transceiver set, so can't set threshold to 2 - vm.expectRevert(abi.encodeWithSelector(IManagerBase.ThresholdTooHigh.selector, 2, 1)); - nttManager.setThreshold(2); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.ThresholdTooHigh.selector, chainId2, 2, 1) + ); + nttManager.setThreshold(chainId2, 2); } function test_canSetThreshold() public { DummyTransceiver e1 = new DummyTransceiver(address(nttManager)); DummyTransceiver e2 = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e1)); - nttManager.setTransceiver(address(e2)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e1)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e2)); - nttManager.setThreshold(1); - nttManager.setThreshold(2); - nttManager.setThreshold(1); + nttManager.setThreshold(chainId2, 1); + nttManager.setThreshold(chainId2, 2); + nttManager.setThreshold(chainId2, 1); } function test_cantSetThresholdToZero() public { DummyTransceiver e = new DummyTransceiver(address(nttManager)); nttManager.setTransceiver(address(e)); - vm.expectRevert(abi.encodeWithSelector(IManagerBase.ZeroThreshold.selector)); - nttManager.setThreshold(0); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.ZeroThreshold.selector, chainId2) + ); + nttManager.setThreshold(chainId2, 0); } function test_onlyOwnerCanSetThreshold() public { @@ -419,7 +424,7 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { vm.expectRevert( abi.encodeWithSelector(OwnableUpgradeable.OwnableUnauthorizedAccount.selector, notOwner) ); - nttManager.setThreshold(1); + nttManager.setThreshold(chainId2, 1); } // == threshold @@ -442,8 +447,12 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { // === attestation function test_onlyEnabledTransceiversCanAttest() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.removeTransceiver(address(e1)); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); + nttManagerOther.disableRecvTransceiverForChain( + TransceiverHelpersLib.SENDING_CHAIN_ID, address(e1) + ); bytes32 peer = toWormholeFormat(address(nttManager)); nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9, type(uint64).max); @@ -459,8 +468,10 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { } function test_onlyPeerNttManagerNoRateLimitingCanAttest() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.setThreshold(2); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); + nttManagerOther.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 2); bytes32 peer = toWormholeFormat(address(nttManager)); @@ -480,8 +491,10 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { } function test_attestSimple() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.setThreshold(2); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); + nttManagerOther.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 2); // register nttManager peer bytes32 peer = toWormholeFormat(address(nttManager)); @@ -503,8 +516,10 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { } function test_attestTwice() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.setThreshold(2); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); + nttManagerOther.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 2); // register nttManager peer bytes32 peer = toWormholeFormat(address(nttManager)); @@ -532,8 +547,9 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { } function test_attestDisabled() public { - (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers(nttManagerOther); - nttManagerOther.setThreshold(2); + DummyTransceiver e1 = TransceiverHelpersLib.setup_one_transceiver( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); bytes32 peer = toWormholeFormat(address(nttManager)); nttManagerOther.setPeer(TransceiverHelpersLib.SENDING_CHAIN_ID, peer, 9, type(uint64).max); @@ -553,7 +569,9 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { transceivers ); - nttManagerOther.removeTransceiver(address(e1)); + TransceiverHelpersLib.disableAndRemoveTransceiver( + nttManagerOther, TransceiverHelpersLib.SENDING_CHAIN_ID, address(e1) + ); bytes32 hash = TransceiverStructs.nttManagerMessageDigest(TransceiverHelpersLib.SENDING_CHAIN_ID, m); @@ -666,8 +684,9 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { function test_attestationQuorum() public { address user_B = address(0x456); - (DummyTransceiver e1, DummyTransceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManagerOther); + (DummyTransceiver e1, DummyTransceiver e2) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther + ); TrimmedAmount transferAmount = packTrimmedAmount(50, 8); @@ -713,6 +732,13 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { } function test_transfersOnForkedChains() public { + nttManager.enableSendTransceiverForChain( + TransceiverHelpersLib.SENDING_CHAIN_ID, address(dummyTransceiver) + ); + nttManager.enableRecvTransceiverForChain( + TransceiverHelpersLib.SENDING_CHAIN_ID, address(dummyTransceiver) + ); + uint256 evmChainId = block.chainid; address user_A = address(0x123); @@ -927,8 +953,8 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { address user_B = address(0x456); DummyToken token = DummyToken(nttManager.token()); TrimmedAmount transferAmount = packTrimmedAmount(50, 8); - (ITransceiverReceiver e1, ITransceiverReceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManagerOther); + (ITransceiverReceiver e1, ITransceiverReceiver e2) = TransceiverHelpersLib + .setup_transceivers(TransceiverHelpersLib.SENDING_CHAIN_ID, nttManagerOther); // Step 1 (contract is deployed by setUp()) ITransceiverReceiver[] memory transceivers = new ITransceiverReceiver[](2); @@ -1039,7 +1065,9 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { ); { DummyTransceiver e = new DummyTransceiver(address(newNttManagerNoRateLimiting)); - newNttManagerNoRateLimiting.setTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver( + newNttManagerNoRateLimiting, TransceiverHelpersLib.SENDING_CHAIN_ID, address(e) + ); } address user_A = address(0x123); @@ -1060,9 +1088,10 @@ contract TestNttManagerNoRateLimiting is Test, IRateLimiterEvents { vm.stopPrank(); // Check that we can receive a transfer - (DummyTransceiver e1,) = - TransceiverHelpersLib.setup_transceivers(newNttManagerNoRateLimiting); - newNttManagerNoRateLimiting.setThreshold(1); + (DummyTransceiver e1,) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, newNttManagerNoRateLimiting + ); + newNttManagerNoRateLimiting.setThreshold(TransceiverHelpersLib.SENDING_CHAIN_ID, 1); bytes memory transceiverMessage; bytes memory tokenTransferMessage; diff --git a/evm/test/Ownership.t.sol b/evm/test/Ownership.t.sol index d2c0127dd..878c5db6b 100644 --- a/evm/test/Ownership.t.sol +++ b/evm/test/Ownership.t.sol @@ -12,6 +12,7 @@ import {DummyToken} from "./NttManager.t.sol"; contract OwnershipTests is Test { NttManager nttManager; uint16 constant chainId = 7; + uint16 constant chainId2 = 8; function setUp() public { DummyToken t = new DummyToken(); @@ -33,8 +34,6 @@ contract OwnershipTests is Test { // TODO: use setup_transceivers here DummyTransceiver e1 = new DummyTransceiver(address(nttManager)); nttManager.setTransceiver(address(e1)); - nttManager.setThreshold(1); - checkOwnership(e1, nttManager.owner()); } } diff --git a/evm/test/PerChainTransceivers.t.sol b/evm/test/PerChainTransceivers.t.sol new file mode 100644 index 000000000..55ceddd8c --- /dev/null +++ b/evm/test/PerChainTransceivers.t.sol @@ -0,0 +1,920 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity >=0.8.8 <0.9.0; + +import "forge-std/Test.sol"; +import "forge-std/console.sol"; + +import "../src/NttManager/NttManagerNoRateLimiting.sol"; +import "../src/Transceiver/Transceiver.sol"; +import "../src/interfaces/INttManager.sol"; +import "../src/interfaces/IRateLimiter.sol"; +import "../src/interfaces/ITransceiver.sol"; +import "../src/interfaces/IManagerBase.sol"; +import "../src/interfaces/IRateLimiterEvents.sol"; +import "../src/NttManager/TransceiverRegistry.sol"; +import {Utils} from "./libraries/Utils.sol"; +import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; +import "../src/interfaces/IWormholeTransceiver.sol"; +import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; +import "../src/libraries/TransceiverStructs.sol"; +import "./mocks/MockNttManager.sol"; +import "./mocks/MockTransceivers.sol"; + +import "openzeppelin-contracts/contracts/token/ERC20/ERC20.sol"; +import "openzeppelin-contracts/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import "wormhole-solidity-sdk/interfaces/IWormhole.sol"; +import "wormhole-solidity-sdk/testing/helpers/WormholeSimulator.sol"; +import "wormhole-solidity-sdk/Utils.sol"; +//import "wormhole-solidity-sdk/testing/WormholeRelayerTest.sol"; + +contract TestPerChainTransceivers is Test, IRateLimiterEvents { + MockNttManagerContract nttManagerChain1; + NttManager nttManagerChain2; + NttManager nttManagerChain3; + + using TrimmedAmountLib for uint256; + using TrimmedAmountLib for TrimmedAmount; + + uint16 constant chainId1 = 7; + uint16 constant chainId2 = 100; + uint16 constant chainId3 = 101; + uint8 constant FAST_CONSISTENCY_LEVEL = 200; + uint256 constant GAS_LIMIT = 500000; + + uint16 constant SENDING_CHAIN_ID = 1; + uint256 constant DEVNET_GUARDIAN_PK = + 0xcfb12303a19cde580bb4dd771639b0d26bc68353645571a8cff516ab2ee113a0; + WormholeSimulator guardian; + uint256 initialBlockTimestamp; + + WormholeTransceiver wormholeTransceiverChain1; + WormholeTransceiver secondWormholeTransceiverChain1; + WormholeTransceiver wormholeTransceiverChain2; + WormholeTransceiver secondWormholeTransceiverChain2; + WormholeTransceiver wormholeTransceiverChain3; + WormholeTransceiver secondWormholeTransceiverChain3; + address userA = address(0x123); + address userB = address(0x456); + address userC = address(0x789); + address userD = address(0xABC); + + address relayer = address(0x28D8F1Be96f97C1387e94A53e00eCcFb4E75175a); + IWormhole wormhole = IWormhole(0x4a8bc80Ed5a4067f1CCf107057b8270E0cC11A78); + + // This function sets up the following config: + // - A manager on each of three chains. + // - Two transceivers on each chain, all interconnected as peers. + // - On chain one, it sets a default threshold of one and a per-chain threshold of two for chain three. + // - On chain three, it sets a default threshold of one and a per-chain threshold of two for chain one. + + function setUp() public { + string memory url = "https://ethereum-sepolia-rpc.publicnode.com"; + vm.createSelectFork(url); + initialBlockTimestamp = vm.getBlockTimestamp(); + + guardian = new WormholeSimulator(address(wormhole), DEVNET_GUARDIAN_PK); + + vm.chainId(chainId1); + DummyToken t1 = new DummyToken(); + NttManager implementation = new MockNttManagerContract( + address(t1), IManagerBase.Mode.LOCKING, chainId1, 1 days, false + ); + + nttManagerChain1 = + MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); + nttManagerChain1.initialize(); + + // Create the first transceiver, from chain 1 to chain 2. + WormholeTransceiver wormholeTransceiverChain1Implementation = new MockWormholeTransceiverContract( + address(nttManagerChain1), + address(wormhole), + address(relayer), + address(0x0), + FAST_CONSISTENCY_LEVEL, + GAS_LIMIT + ); + wormholeTransceiverChain1 = MockWormholeTransceiverContract( + address(new ERC1967Proxy(address(wormholeTransceiverChain1Implementation), "")) + ); + + // Only the deployer should be able to initialize + vm.prank(userA); + vm.expectRevert( + abi.encodeWithSelector(ITransceiver.UnexpectedDeployer.selector, address(this), userA) + ); + wormholeTransceiverChain1.initialize(); + + // Actually initialize properly now + wormholeTransceiverChain1.initialize(); + + nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); + + // Create the second transceiver for chain 1. + WormholeTransceiver secondWormholeTransceiverChain1Implementation = new MockWormholeTransceiverContract( + address(nttManagerChain1), + address(wormhole), + address(relayer), + address(0x0), + FAST_CONSISTENCY_LEVEL, + GAS_LIMIT + ); + secondWormholeTransceiverChain1 = MockWormholeTransceiverContract( + address(new ERC1967Proxy(address(secondWormholeTransceiverChain1Implementation), "")) + ); + + secondWormholeTransceiverChain1.initialize(); + nttManagerChain1.setTransceiver(address(secondWormholeTransceiverChain1)); + + // Chain 2 setup + vm.chainId(chainId2); + DummyToken t2 = new DummyTokenMintAndBurn(); + NttManager implementationChain2 = new MockNttManagerContract( + address(t2), IManagerBase.Mode.BURNING, chainId2, 1 days, false + ); + + nttManagerChain2 = + MockNttManagerContract(address(new ERC1967Proxy(address(implementationChain2), ""))); + nttManagerChain2.initialize(); + + WormholeTransceiver wormholeTransceiverChain2Implementation = new MockWormholeTransceiverContract( + address(nttManagerChain2), + address(wormhole), + address(relayer), + address(0x0), + FAST_CONSISTENCY_LEVEL, + GAS_LIMIT + ); + wormholeTransceiverChain2 = MockWormholeTransceiverContract( + address(new ERC1967Proxy(address(wormholeTransceiverChain2Implementation), "")) + ); + wormholeTransceiverChain2.initialize(); + + nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2)); + + // Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here. + nttManagerChain1.setPeer( + chainId2, bytes32(uint256(uint160(address(nttManagerChain2)))), 9, type(uint64).max + ); + nttManagerChain2.setPeer( + chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 7, type(uint64).max + ); + + // Create the second transceiver for chain 2. + WormholeTransceiver secondWormholeTransceiverChain2Implementation = new MockWormholeTransceiverContract( + address(nttManagerChain2), + address(wormhole), + address(relayer), + address(0x0), + FAST_CONSISTENCY_LEVEL, + GAS_LIMIT + ); + secondWormholeTransceiverChain2 = MockWormholeTransceiverContract( + address(new ERC1967Proxy(address(secondWormholeTransceiverChain2Implementation), "")) + ); + + secondWormholeTransceiverChain2.initialize(); + nttManagerChain2.setTransceiver(address(secondWormholeTransceiverChain2)); + + // Set peers for the transceivers + wormholeTransceiverChain1.setWormholePeer( + chainId2, bytes32(uint256(uint160(address(wormholeTransceiverChain2)))) + ); + wormholeTransceiverChain2.setWormholePeer( + chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1)))) + ); + secondWormholeTransceiverChain1.setWormholePeer( + chainId2, bytes32(uint256(uint160(address(secondWormholeTransceiverChain2)))) + ); + secondWormholeTransceiverChain2.setWormholePeer( + chainId1, bytes32(uint256(uint160(address(secondWormholeTransceiverChain1)))) + ); + + // Chain 3 setup + vm.chainId(chainId3); + DummyToken t3 = new DummyTokenMintAndBurn(); + NttManager implementationChain3 = new MockNttManagerContract( + address(t3), IManagerBase.Mode.BURNING, chainId3, 1 days, false + ); + + nttManagerChain3 = + MockNttManagerContract(address(new ERC1967Proxy(address(implementationChain3), ""))); + nttManagerChain3.initialize(); + + WormholeTransceiver wormholeTransceiverChain3Implementation = new MockWormholeTransceiverContract( + address(nttManagerChain3), + address(wormhole), + address(relayer), + address(0x0), + FAST_CONSISTENCY_LEVEL, + GAS_LIMIT + ); + wormholeTransceiverChain3 = MockWormholeTransceiverContract( + address(new ERC1967Proxy(address(wormholeTransceiverChain3Implementation), "")) + ); + wormholeTransceiverChain3.initialize(); + + nttManagerChain3.setTransceiver(address(wormholeTransceiverChain3)); + + // Register peer contracts for the nttManager and transceiver. Transceivers and nttManager each have the concept of peers here. + nttManagerChain1.setPeer( + chainId3, bytes32(uint256(uint160(address(nttManagerChain3)))), 9, type(uint64).max + ); + nttManagerChain3.setPeer( + chainId1, bytes32(uint256(uint160(address(nttManagerChain1)))), 7, type(uint64).max + ); + + // Create the second transceiver, from chain 3 to chain 1. + WormholeTransceiver secondWormholeTransceiverChain3Implementation = new MockWormholeTransceiverContract( + address(nttManagerChain3), + address(wormhole), + address(relayer), + address(0x0), + FAST_CONSISTENCY_LEVEL, + GAS_LIMIT + ); + secondWormholeTransceiverChain3 = MockWormholeTransceiverContract( + address(new ERC1967Proxy(address(secondWormholeTransceiverChain3Implementation), "")) + ); + + // Actually initialize properly now + secondWormholeTransceiverChain3.initialize(); + + nttManagerChain3.setTransceiver(address(secondWormholeTransceiverChain3)); + + // Set peers for the transceivers + wormholeTransceiverChain1.setWormholePeer( + chainId3, bytes32(uint256(uint160(address(wormholeTransceiverChain3)))) + ); + wormholeTransceiverChain3.setWormholePeer( + chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1)))) + ); + wormholeTransceiverChain3.setWormholePeer( + chainId2, bytes32(uint256(uint160(address(wormholeTransceiverChain2)))) + ); + wormholeTransceiverChain2.setWormholePeer( + chainId3, bytes32(uint256(uint160(address(wormholeTransceiverChain3)))) + ); + secondWormholeTransceiverChain1.setWormholePeer( + chainId3, bytes32(uint256(uint160(address(secondWormholeTransceiverChain3)))) + ); + secondWormholeTransceiverChain3.setWormholePeer( + chainId1, bytes32(uint256(uint160(address(secondWormholeTransceiverChain1)))) + ); + secondWormholeTransceiverChain2.setWormholePeer( + chainId3, bytes32(uint256(uint160(address(secondWormholeTransceiverChain3)))) + ); + secondWormholeTransceiverChain3.setWormholePeer( + chainId2, bytes32(uint256(uint160(address(secondWormholeTransceiverChain2)))) + ); + } + + function test_setUp() public { + address[] memory transceivers = nttManagerChain1.getTransceivers(); + require(transceivers.length == 2, "Invalid number of transceivers"); + require(transceivers[0] == address(wormholeTransceiverChain1), "Transceiver one is invalid"); + require( + transceivers[1] == address(secondWormholeTransceiverChain1), + "Transceiver two is invalid" + ); + + TransceiverRegistry.TransceiverInfo[] memory info = nttManagerChain1.getTransceiverInfo(); + require(info.length == 2, "Invalid number of transceiver infos"); + } + + function test_transceiverSetters() public { + // Make sure nothing is enabled for either sending or receiving. + require( + nttManagerChain1.getChainsEnabledForSending().length == 0, + "There should be no chains enabled for sending to start with" + ); + require( + nttManagerChain1.getChainsEnabledForReceiving().length == 0, + "There should be no chains enabled for receiving to start with" + ); + + // Chain 2 + require( + nttManagerChain1.getEnabledSendTransceiversForChain(chainId2).length == 0, + "There should be nothing enabled for sending on chain two to start with" + ); + require( + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(chainId2) == 0, + "There should be nothing enabled for receiving on chain two to start with" + ); + + // Chain 3 + require( + nttManagerChain1.getEnabledSendTransceiversForChain(chainId3).length == 0, + "There should be nothing enabled for sending on chain three to start with" + ); + require( + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(chainId3) == 0, + "There should be nothing enabled for receiving on chain three to start with" + ); + + // Enable a sender on chain two. + nttManagerChain1.enableSendTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + address[] memory sendTrans = nttManagerChain1.getEnabledSendTransceiversForChain(chainId2); + require(sendTrans.length == 1, "Sending chains length is wrong for chain two #1"); + require( + sendTrans[0] == address(wormholeTransceiverChain1), + "Sending chains is wrong for chain two #1" + ); + uint16[] memory sendChains = nttManagerChain1.getChainsEnabledForSending(); + require(sendChains.length == 1, "There should be one chain enabled for sending"); + require(sendChains[0] == chainId2, "Chain two should be enabled for sending"); + + // Enable a receiver on chain two. + nttManagerChain1.enableRecvTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + require( + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(chainId2) == 0x01, + "Receiving bitmap is wrong for chain two #1" + ); + uint16[] memory recvChains = nttManagerChain1.getChainsEnabledForReceiving(); + require(recvChains.length == 1, "There should be one chain enabled for receiving"); + require(recvChains[0] == chainId2, "Chain two should be enabled for receiving #1"); + + // Enable a sender on chain three. + nttManagerChain1.enableSendTransceiverForChain(chainId3, address(wormholeTransceiverChain1)); + sendTrans = nttManagerChain1.getEnabledSendTransceiversForChain(chainId3); + require(sendTrans.length == 1, "Sending chains length is wrong for chain three #1"); + require( + sendTrans[0] == address(wormholeTransceiverChain1), + "Sending chains is wrong for chain three #1" + ); + sendChains = nttManagerChain1.getChainsEnabledForSending(); + require(sendChains.length == 2, "There should be two chain enabled for sending"); + require(sendChains[0] == chainId2, "Chain two should be enabled for sending"); + require(sendChains[1] == chainId3, "Chain three should be enabled for sending"); + + // Enable a receiver on chain three. + nttManagerChain1.enableRecvTransceiverForChain( + chainId3, address(secondWormholeTransceiverChain1) + ); + require( + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(chainId3) == 0x02, + "Receiving bitmap is wrong for chain three #1" + ); + recvChains = nttManagerChain1.getChainsEnabledForReceiving(); + require(recvChains.length == 2, "There should be two chains enabled for receiving"); + require(recvChains[0] == chainId2, "Chain two should be enabled for receiving #2"); + require(recvChains[1] == chainId3, "Chain three should be enabled for receiving #1"); + require( + nttManagerChain1.getThreshold(chainId3) == 1, "Threshold is wrong for chain three #1" + ); + + // Enable two receivers on chain two. That should not change the threshold. + nttManagerChain1.enableRecvTransceiverForChain( + chainId2, address(secondWormholeTransceiverChain1) + ); + require( + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(chainId2) == 0x03, + "Receiving bitmap is wrong for chain two #2" + ); + recvChains = nttManagerChain1.getChainsEnabledForReceiving(); + require(recvChains.length == 2, "There should be two chains enabled for receiving"); + require(recvChains[0] == chainId2, "Chain two should be enabled for receiving #3"); + require(recvChains[1] == chainId3, "Chain three should be enabled for receiving #2"); + require(nttManagerChain1.getThreshold(chainId2) == 1, "Threshold is wrong for chain two #2"); + + nttManagerChain1.setThreshold(chainId2, 2); + require(nttManagerChain1.getThreshold(chainId2) == 2, "Threshold is wrong for chain two #3"); + + // Disable one receiver on chain two. This should reduce the threshold. + nttManagerChain1.disableRecvTransceiverForChain( + chainId2, address(wormholeTransceiverChain1) + ); + require( + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(chainId2) == 0x02, + "Receiving bitmap is wrong for chain two #3" + ); + recvChains = nttManagerChain1.getChainsEnabledForReceiving(); + require(recvChains.length == 2, "There should be two chains enabled for receiving"); + require(recvChains[0] == chainId2, "Chain two should be enabled for receiving #4"); + require(recvChains[1] == chainId3, "Chain three should be enabled for receiving #3"); + require(nttManagerChain1.getThreshold(chainId2) == 1, "Threshold is wrong for chain two #4"); + + // Disable the other receiver on chain two. + nttManagerChain1.disableRecvTransceiverForChain( + chainId2, address(secondWormholeTransceiverChain1) + ); + require( + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(chainId2) == 0x00, + "Receiving bitmap is wrong for chain two #4" + ); + recvChains = nttManagerChain1.getChainsEnabledForReceiving(); + require(recvChains.length == 1, "There should be only one chain enabled for receiving"); + require(recvChains[0] == chainId3, "Chain three should be enabled for receiving #5"); + require(nttManagerChain1.getThreshold(chainId2) == 0, "Threshold is wrong for chain two #4"); + + // Disable one receiver on chain three. + nttManagerChain1.disableRecvTransceiverForChain( + chainId3, address(secondWormholeTransceiverChain1) + ); + require( + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(chainId3) == 0x00, + "Receiving bitmap is wrong for chain three #3" + ); + recvChains = nttManagerChain1.getChainsEnabledForReceiving(); + require(recvChains.length == 0, "There should be one chain enabled for receiving"); + require( + nttManagerChain1.getThreshold(chainId3) == 0, "Threshold is wrong for chain three #3" + ); + + // Make sure our senders haven't changed on either chain. + sendTrans = nttManagerChain1.getEnabledSendTransceiversForChain(chainId2); + require(sendTrans.length == 1, "Sending chains length is wrong for chain two #2"); + require( + sendTrans[0] == address(wormholeTransceiverChain1), + "Sending chains is wrong for chain two #2" + ); + + sendTrans = nttManagerChain1.getEnabledSendTransceiversForChain(chainId3); + require(sendTrans.length == 1, "Sending chains length is wrong for chain three #2"); + require( + sendTrans[0] == address(wormholeTransceiverChain1), + "Sending chains is wrong for chain three #2" + ); + + sendChains = nttManagerChain1.getChainsEnabledForSending(); + require(sendChains.length == 2, "There should be two chain enabled for sending #2"); + require(sendChains[0] == chainId2, "Chain two should be enabled for sending #2"); + require(sendChains[1] == chainId3, "Chain three should be enabled for sending #2"); + } + + // This test does a transfer between chain one and chain two. + // Since the receive thresholds are set to one, posting a VAA from only one transceiver completes the transfer. + function test_thresholdLessThanNumReceivers() public { + // On manager one, enable two transceivers for sending and receiving on chain two. Set the threshold to one. + nttManagerChain1.enableSendTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + nttManagerChain1.enableSendTransceiverForChain( + chainId2, address(secondWormholeTransceiverChain1) + ); + nttManagerChain1.enableRecvTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + nttManagerChain1.enableRecvTransceiverForChain( + chainId2, address(secondWormholeTransceiverChain1) + ); + nttManagerChain1.setThreshold(chainId2, 1); + + // On manager two, enable two transceivers for sending and receiving on chain two. Set the threshold to one. + nttManagerChain2.enableSendTransceiverForChain(chainId1, address(wormholeTransceiverChain2)); + nttManagerChain2.enableSendTransceiverForChain( + chainId1, address(secondWormholeTransceiverChain2) + ); + nttManagerChain2.enableRecvTransceiverForChain(chainId1, address(wormholeTransceiverChain2)); + nttManagerChain2.enableRecvTransceiverForChain( + chainId1, address(secondWormholeTransceiverChain2) + ); + nttManagerChain2.setThreshold(chainId1, 1); + + vm.chainId(chainId1); + + // Setting up the transfer + DummyToken token1 = DummyToken(nttManagerChain1.token()); + DummyToken token2 = DummyTokenMintAndBurn(nttManagerChain2.token()); + + uint8 decimals = token1.decimals(); + uint256 sendingAmount = 5 * 10 ** decimals; + token1.mintDummy(address(userA), 5 * 10 ** decimals); + + // Transfer tokens from chain one to chain two through standard means (not relayer) + vm.startPrank(userA); + token1.approve(address(nttManagerChain1), sendingAmount); + vm.recordLogs(); + { + uint256 nttManagerBalanceBefore = token1.balanceOf(address(nttManagerChain1)); + uint256 userBalanceBefore = token1.balanceOf(address(userA)); + nttManagerChain1.transfer(sendingAmount, chainId2, bytes32(uint256(uint160(userB)))); + + // Balance check on funds going in and out working as expected + uint256 nttManagerBalanceAfter = token1.balanceOf(address(nttManagerChain1)); + uint256 userBalanceAfter = token1.balanceOf(address(userB)); + require( + nttManagerBalanceBefore + sendingAmount == nttManagerBalanceAfter, + "Should be locking the tokens" + ); + require( + userBalanceBefore - sendingAmount == userBalanceAfter, + "User should have sent tokens" + ); + } + + vm.stopPrank(); + + // Get and sign the log to go down the other pipes. There should be two messages since we have two transceivers. + Vm.Log[] memory entries = guardian.fetchWormholeMessageFromLog(vm.getRecordedLogs()); + require(2 == entries.length, "Unexpected number of log entries 1"); + bytes[] memory encodedVMs = new bytes[](entries.length); + for (uint256 i = 0; i < encodedVMs.length; i++) { + encodedVMs[i] = guardian.fetchSignedMessageFromLogs(entries[i], chainId1); + } + + // Chain2 verification and checks + vm.chainId(chainId2); + + uint256 supplyBefore = token2.totalSupply(); + wormholeTransceiverChain2.receiveMessage(encodedVMs[0]); + uint256 supplyAfter = token2.totalSupply(); + + require(sendingAmount + supplyBefore == supplyAfter, "Supplies dont match #1"); + require(token2.balanceOf(userB) == sendingAmount, "User didn't receive tokens"); + require(token2.balanceOf(address(nttManagerChain2)) == 0, "NttManager has unintended funds"); + + // Go back the other way from a THIRD user + vm.prank(userB); + token2.transfer(userC, sendingAmount); + + vm.startPrank(userC); + token2.approve(address(nttManagerChain2), sendingAmount); + vm.recordLogs(); + + // Supply checks on the transfer + supplyBefore = token2.totalSupply(); + nttManagerChain2.transfer( + sendingAmount, + chainId1, + toWormholeFormat(userD), + toWormholeFormat(userC), + false, + encodeTransceiverInstruction(true) + ); + + supplyAfter = token2.totalSupply(); + + require(sendingAmount - supplyBefore == supplyAfter, "Supplies don't match"); + require(token2.balanceOf(userB) == 0, "OG user receive tokens"); + require(token2.balanceOf(userC) == 0, "Sending user didn't receive tokens"); + require( + token2.balanceOf(address(nttManagerChain2)) == 0, + "NttManager didn't receive unintended funds" + ); + + // Get and sign the log to go down the other pipe. Thank you to whoever wrote this code in the past! + entries = guardian.fetchWormholeMessageFromLog(vm.getRecordedLogs()); + require(2 == entries.length, "Unexpected number of log entries 2"); + encodedVMs = new bytes[](entries.length); + for (uint256 i = 0; i < encodedVMs.length; i++) { + encodedVMs[i] = guardian.fetchSignedMessageFromLogs(entries[i], chainId2); + } + + // Chain1 verification and checks with the receiving of the message + vm.chainId(chainId1); + + supplyBefore = token1.totalSupply(); + wormholeTransceiverChain1.receiveMessage(encodedVMs[0]); + supplyAfter = token1.totalSupply(); + + require(supplyBefore == supplyAfter, "Supplies don't match between operations"); + require(token1.balanceOf(userB) == 0, "OG user receive tokens"); + require(token1.balanceOf(userC) == 0, "Sending user didn't receive tokens"); + require(token1.balanceOf(userD) == sendingAmount, "Transfer did not complete"); + + // Submitting the second message back on chain one should not change anything. + supplyBefore = token1.totalSupply(); + secondWormholeTransceiverChain1.receiveMessage(encodedVMs[1]); + supplyAfter = token1.totalSupply(); + + require(supplyBefore == supplyAfter, "Supplies don't match between operations"); + require(token1.balanceOf(userB) == 0, "OG user receive tokens"); + require(token1.balanceOf(userC) == 0, "Sending user didn't receive tokens"); + require( + token1.balanceOf(userD) == sendingAmount, + "Second message updated the balance when it shouldn't have" + ); + } + + function test_someReverts() public { + // Threshold too high (since nothing has been enabled yet). + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.ThresholdTooHigh.selector, chainId2, 1, 0) + ); + nttManagerChain1.setThreshold(chainId2, 1); + + // Can't set the threshold to zero. + nttManagerChain1.enableRecvTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.ZeroThreshold.selector, chainId2) + ); + nttManagerChain1.setThreshold(chainId2, 0); + + // Threshold too high when something is actually enabled. + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.ThresholdTooHigh.selector, chainId2, 2, 1) + ); + nttManagerChain1.setThreshold(chainId2, 2); + + // Can't enable sending for chain zero. + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, 0)); + nttManagerChain1.getEnabledSendTransceiversForChain(0); + + // Can't enable receiving for chain zero. + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, 0)); + nttManagerChain1.getEnabledRecvTransceiversBitmapForChain(0); + + // Can't add transceiver with zero address. + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector) + ); + nttManagerChain1.setTransceiver(address(0)); + + // Can't add the same transceiver twice. + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.TransceiverAlreadyEnabled.selector, + address(wormholeTransceiverChain1) + ) + ); + nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); + + // Can't remove transceiver with zero address. + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector) + ); + nttManagerChain1.removeTransceiver(address(0)); + + // Can't remove a transceiver that hasn't been added. + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.NonRegisteredTransceiver.selector, + address(wormholeTransceiverChain2) + ) + ); + nttManagerChain1.removeTransceiver(address(wormholeTransceiverChain2)); + + // Can't remove a transceiver that has already been removed. + nttManagerChain1.disableRecvTransceiverForChain( + chainId2, address(wormholeTransceiverChain1) + ); + + nttManagerChain1.removeTransceiver(address(wormholeTransceiverChain1)); + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.DisabledTransceiver.selector, address(wormholeTransceiverChain1) + ) + ); + nttManagerChain1.removeTransceiver(address(wormholeTransceiverChain1)); + + nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); + + // Can't disable a send transcevier that's not enabled. + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.TransceiverAlreadyDisabled.selector, + address(wormholeTransceiverChain1) + ) + ); + nttManagerChain1.disableSendTransceiverForChain( + chainId2, address(wormholeTransceiverChain1) + ); + + // Can't enable a send transcevier that is already enabled. + nttManagerChain1.enableSendTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.TransceiverAlreadyEnabled.selector, + address(wormholeTransceiverChain1) + ) + ); + nttManagerChain1.enableSendTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + + // Can't disable a receive transcevier that's not enabled. + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.TransceiverAlreadyDisabled.selector, + address(secondWormholeTransceiverChain1) + ) + ); + nttManagerChain1.disableRecvTransceiverForChain( + chainId2, address(secondWormholeTransceiverChain1) + ); + + // Can't enable a receive transceiver that's already enabled. + nttManagerChain1.enableRecvTransceiverForChain(chainId3, address(wormholeTransceiverChain1)); + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.TransceiverAlreadyEnabled.selector, + address(wormholeTransceiverChain1) + ) + ); + nttManagerChain1.enableRecvTransceiverForChain(chainId3, address(wormholeTransceiverChain1)); + + // Can't enable transceiver with zero address. + vm.expectRevert( + abi.encodeWithSelector(TransceiverRegistry.InvalidTransceiverZeroAddress.selector) + ); + nttManagerChain1.enableRecvTransceiverForChain(chainId3, address(0)); + + // Can't enable transceiver with chain zero. + vm.expectRevert(abi.encodeWithSelector(TransceiverRegistry.InvalidChain.selector, 0)); + nttManagerChain1.enableRecvTransceiverForChain(0, address(wormholeTransceiverChain1)); + } + + // This test does a transfer between chain one and chain three. + // Since the threshold for these two chains is two, the transfer is not completed until both VAAs are posted. + function test_thresholdEqualToNumberOfReceivers() public { + // On manager one, enable two transceivers for sending and receiving on chain two. Set the threshold to one. + nttManagerChain1.enableSendTransceiverForChain(chainId3, address(wormholeTransceiverChain1)); + nttManagerChain1.enableSendTransceiverForChain( + chainId3, address(secondWormholeTransceiverChain1) + ); + nttManagerChain1.enableRecvTransceiverForChain(chainId3, address(wormholeTransceiverChain1)); + nttManagerChain1.enableRecvTransceiverForChain( + chainId3, address(secondWormholeTransceiverChain1) + ); + nttManagerChain1.setThreshold(chainId3, 2); + + // On manager two, enable two transceivers for sending and receiving on chain two. Set the threshold to one. + nttManagerChain3.enableSendTransceiverForChain(chainId1, address(wormholeTransceiverChain3)); + nttManagerChain3.enableSendTransceiverForChain( + chainId1, address(secondWormholeTransceiverChain3) + ); + nttManagerChain3.enableRecvTransceiverForChain(chainId1, address(wormholeTransceiverChain3)); + nttManagerChain3.enableRecvTransceiverForChain( + chainId1, address(secondWormholeTransceiverChain3) + ); + nttManagerChain3.setThreshold(chainId1, 2); + + vm.chainId(chainId1); + + // Setting up the transfer + DummyToken token1 = DummyToken(nttManagerChain1.token()); + DummyToken token3 = DummyTokenMintAndBurn(nttManagerChain3.token()); + + uint8 decimals = token1.decimals(); + uint256 sendingAmount = 5 * 10 ** decimals; + token1.mintDummy(address(userA), 5 * 10 ** decimals); + vm.startPrank(userA); + token1.approve(address(nttManagerChain1), sendingAmount); + + vm.recordLogs(); + + // Send token from chain 1 to chain 3, userB. + { + uint256 nttManagerBalanceBefore = token1.balanceOf(address(nttManagerChain1)); + uint256 userBalanceBefore = token1.balanceOf(address(userA)); + nttManagerChain1.transfer(sendingAmount, chainId3, bytes32(uint256(uint160(userB)))); + + // Balance check on funds going in and out working as expected + uint256 nttManagerBalanceAfter = token1.balanceOf(address(nttManagerChain1)); + uint256 userBalanceAfter = token1.balanceOf(address(userB)); + require( + nttManagerBalanceBefore + sendingAmount == nttManagerBalanceAfter, + "Should be locking the tokens" + ); + require( + userBalanceBefore - sendingAmount == userBalanceAfter, + "User should have sent tokens" + ); + } + + vm.stopPrank(); + + // Get and sign the log to go down the other pipes. There should be two messages since we have two transceivers. + Vm.Log[] memory entries = guardian.fetchWormholeMessageFromLog(vm.getRecordedLogs()); + require(2 == entries.length, "Unexpected number of log entries 3"); + bytes[] memory encodedVMs = new bytes[](entries.length); + for (uint256 i = 0; i < encodedVMs.length; i++) { + encodedVMs[i] = guardian.fetchSignedMessageFromLogs(entries[i], chainId1); + } + + // Chain3 verification and checks + vm.chainId(chainId3); + + uint256 supplyBefore = token3.totalSupply(); + + // Submit the first message on chain 3. The numbers shouldn't change yet since the threshold is two. + wormholeTransceiverChain3.receiveMessage(encodedVMs[0]); + uint256 supplyAfter = token3.totalSupply(); + + require(supplyBefore == supplyAfter, "Supplies changed early"); + require(token3.balanceOf(userB) == 0, "User receive tokens early"); + require(token3.balanceOf(address(nttManagerChain3)) == 0, "NttManager has unintended funds"); + + // Submit the second message and the transfer should complete. + secondWormholeTransceiverChain3.receiveMessage(encodedVMs[1]); + supplyAfter = token3.totalSupply(); + + require(sendingAmount + supplyBefore == supplyAfter, "Supplies dont match"); + require(token3.balanceOf(userB) == sendingAmount, "User didn't receive tokens"); + require(token3.balanceOf(address(nttManagerChain3)) == 0, "NttManager has unintended funds"); + + // Go back the other way from a THIRD user + vm.prank(userB); + token3.transfer(userC, sendingAmount); + + vm.startPrank(userC); + token3.approve(address(nttManagerChain3), sendingAmount); + vm.recordLogs(); + + // Supply checks on the transfer + supplyBefore = token3.totalSupply(); + nttManagerChain3.transfer( + sendingAmount, + chainId1, + toWormholeFormat(userD), + toWormholeFormat(userC), + false, + encodeTransceiverInstruction(true) + ); + + supplyAfter = token3.totalSupply(); + + require(sendingAmount - supplyBefore == supplyAfter, "Supplies don't match"); + require(token3.balanceOf(userB) == 0, "OG user receive tokens"); + require(token3.balanceOf(userC) == 0, "Sending user didn't receive tokens"); + require( + token3.balanceOf(address(nttManagerChain3)) == 0, + "NttManager didn't receive unintended funds" + ); + + // Get and sign the log to go down the other pipe. Thank you to whoever wrote this code in the past! + entries = guardian.fetchWormholeMessageFromLog(vm.getRecordedLogs()); + require(2 == entries.length, "Unexpected number of log entries for response"); + encodedVMs = new bytes[](entries.length); + for (uint256 i = 0; i < encodedVMs.length; i++) { + encodedVMs[i] = guardian.fetchSignedMessageFromLogs(entries[i], chainId3); + } + + // Chain1 verification and checks with the receiving of the message + vm.chainId(chainId1); + + // Submit the first message back on chain one. Nothing should happen because our threshold is two. + supplyBefore = token1.totalSupply(); + wormholeTransceiverChain1.receiveMessage(encodedVMs[0]); + supplyAfter = token1.totalSupply(); + + require(supplyBefore == supplyAfter, "Supplies don't match between operations"); + require(token1.balanceOf(userB) == 0, "OG user receive tokens"); + require(token1.balanceOf(userC) == 0, "Sending user didn't receive tokens"); + require(token1.balanceOf(userD) == 0, "User received funds before they should"); + + // Submit the second message back on chain one. This should update the balance. + supplyBefore = token1.totalSupply(); + secondWormholeTransceiverChain1.receiveMessage(encodedVMs[1]); + supplyAfter = token1.totalSupply(); + + require(supplyBefore == supplyAfter, "Supplies don't match between operations"); + require(token1.balanceOf(userB) == 0, "OG user receive tokens"); + require(token1.balanceOf(userC) == 0, "Sending user didn't receive tokens"); + require(token1.balanceOf(userD) == sendingAmount, "User received funds"); + } + + function test_onlyTransceiverCanCallAttestationReceived() public { + // On manager one, enable two transceivers for sending and receiving on chain two. Set the threshold to one. + nttManagerChain1.enableSendTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + nttManagerChain1.enableRecvTransceiverForChain(chainId2, address(wormholeTransceiverChain1)); + nttManagerChain1.setThreshold(chainId2, 1); + + TransceiverStructs.NttManagerMessage memory payload; + + vm.startPrank(userA); + + vm.expectRevert( + abi.encodeWithSelector( + TransceiverRegistry.CallerNotTransceiver.selector, address(userA) + ) + ); + nttManagerChain1.attestationReceived( + chainId2, toWormholeFormat(address(nttManagerChain2)), payload + ); + } + + function encodeTransceiverInstruction( + bool relayer_off + ) public view returns (bytes memory) { + WormholeTransceiver.WormholeTransceiverInstruction memory instruction = + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + bytes memory encodedInstructionWormhole = + wormholeTransceiverChain1.encodeWormholeTransceiverInstruction(instruction); + TransceiverStructs.TransceiverInstruction memory TransceiverInstruction = TransceiverStructs + .TransceiverInstruction({index: 0, payload: encodedInstructionWormhole}); + TransceiverStructs.TransceiverInstruction[] memory TransceiverInstructions = + new TransceiverStructs.TransceiverInstruction[](1); + TransceiverInstructions[0] = TransceiverInstruction; + return TransceiverStructs.encodeTransceiverInstructions(TransceiverInstructions); + } + + // Encode an instruction for each of the relayers + function encodeTransceiverInstructions( + bool relayer_off + ) public view returns (bytes memory) { + WormholeTransceiver.WormholeTransceiverInstruction memory instruction = + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + + bytes memory encodedInstructionWormhole = + wormholeTransceiverChain1.encodeWormholeTransceiverInstruction(instruction); + + TransceiverStructs.TransceiverInstruction memory TransceiverInstruction1 = + TransceiverStructs.TransceiverInstruction({index: 0, payload: encodedInstructionWormhole}); + TransceiverStructs.TransceiverInstruction memory TransceiverInstruction2 = + TransceiverStructs.TransceiverInstruction({index: 1, payload: encodedInstructionWormhole}); + + TransceiverStructs.TransceiverInstruction[] memory TransceiverInstructions = + new TransceiverStructs.TransceiverInstruction[](2); + + TransceiverInstructions[0] = TransceiverInstruction1; + TransceiverInstructions[1] = TransceiverInstruction2; + + return TransceiverStructs.encodeTransceiverInstructions(TransceiverInstructions); + } +} diff --git a/evm/test/RateLimit.t.sol b/evm/test/RateLimit.t.sol index 159246e20..3277ac6f0 100644 --- a/evm/test/RateLimit.t.sol +++ b/evm/test/RateLimit.t.sol @@ -49,7 +49,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { nttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max); DummyTransceiver e = new DummyTransceiver(address(nttManager)); - nttManager.setTransceiver(address(e)); + TransceiverHelpersLib.setAndEnableTransceiver(nttManager, chainId2, address(e)); } function test_outboundRateLimit_setLimitSimple() public { @@ -525,8 +525,9 @@ contract TestRateLimit is Test, IRateLimiterEvents { function test_inboundRateLimit_simple() public { address user_B = address(0x456); - (DummyTransceiver e1, DummyTransceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManager); + (DummyTransceiver e1, DummyTransceiver e2) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManager + ); DummyToken token = DummyToken(nttManager.token()); @@ -565,8 +566,9 @@ contract TestRateLimit is Test, IRateLimiterEvents { function test_inboundRateLimit_queue() public { address user_B = address(0x456); - (DummyTransceiver e1, DummyTransceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManager); + (DummyTransceiver e1, DummyTransceiver e2) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManager + ); DummyToken token = DummyToken(nttManager.token()); @@ -717,8 +719,9 @@ contract TestRateLimit is Test, IRateLimiterEvents { vm.warp(receiveTime); // now receive 10 tokens from user_B -> user_A - (DummyTransceiver e1, DummyTransceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManager); + (DummyTransceiver e1, DummyTransceiver e2) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManager + ); ITransceiverReceiver[] memory transceivers = new ITransceiverReceiver[](2); transceivers[0] = e1; @@ -811,8 +814,9 @@ contract TestRateLimit is Test, IRateLimiterEvents { } function initializeTransceivers() public returns (ITransceiverReceiver[] memory) { - (DummyTransceiver e1, DummyTransceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManager); + (DummyTransceiver e1, DummyTransceiver e2) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManager + ); ITransceiverReceiver[] memory transceivers = new ITransceiverReceiver[](2); transceivers[0] = e1; @@ -1068,8 +1072,9 @@ contract TestRateLimit is Test, IRateLimiterEvents { address user_B = address(0x456); - (DummyTransceiver e1, DummyTransceiver e2) = - TransceiverHelpersLib.setup_transceivers(nttManager); + (DummyTransceiver e1, DummyTransceiver e2) = TransceiverHelpersLib.setup_transceivers( + TransceiverHelpersLib.SENDING_CHAIN_ID, nttManager + ); DummyToken token = DummyToken(nttManager.token()); diff --git a/evm/test/Upgrades.t.sol b/evm/test/Upgrades.t.sol index 6819ee279..abadeb781 100644 --- a/evm/test/Upgrades.t.sol +++ b/evm/test/Upgrades.t.sol @@ -16,6 +16,7 @@ import {Utils} from "./libraries/Utils.sol"; import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; +import "./libraries/TransceiverHelpers.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -83,7 +84,9 @@ contract TestUpgrades is Test, IRateLimiterEvents { ); wormholeTransceiverChain1.initialize(); - nttManagerChain1.setTransceiver(address(wormholeTransceiverChain1)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain1, chainId2, address(wormholeTransceiverChain1) + ); nttManagerChain1.setOutboundLimit(type(uint64).max); nttManagerChain1.setInboundLimit(type(uint64).max, chainId2); @@ -111,7 +114,9 @@ contract TestUpgrades is Test, IRateLimiterEvents { ); wormholeTransceiverChain2.initialize(); - nttManagerChain2.setTransceiver(address(wormholeTransceiverChain2)); + TransceiverHelpersLib.setAndEnableTransceiver( + nttManagerChain2, chainId1, address(wormholeTransceiverChain2) + ); nttManagerChain2.setOutboundLimit(type(uint64).max); nttManagerChain2.setInboundLimit(type(uint64).max, chainId1); @@ -136,8 +141,8 @@ contract TestUpgrades is Test, IRateLimiterEvents { chainId1, bytes32(uint256(uint160(address(wormholeTransceiverChain1)))) ); - nttManagerChain1.setThreshold(1); - nttManagerChain2.setThreshold(1); + nttManagerChain1.setThreshold(chainId2, 1); + nttManagerChain2.setThreshold(chainId1, 1); vm.chainId(chainId1); } diff --git a/evm/test/libraries/TransceiverHelpers.sol b/evm/test/libraries/TransceiverHelpers.sol index 7b7c39a4d..6c50c5f45 100644 --- a/evm/test/libraries/TransceiverHelpers.sol +++ b/evm/test/libraries/TransceiverHelpers.sol @@ -22,10 +22,54 @@ library TransceiverHelpersLib { DummyTransceiver e2 = new DummyTransceiver(address(nttManager)); nttManager.setTransceiver(address(e1)); nttManager.setTransceiver(address(e2)); - nttManager.setThreshold(2); + nttManager.setThreshold(SENDING_CHAIN_ID, 2); return (e1, e2); } + function setup_transceivers( + uint16 chainId, + NttManager nttManager + ) internal returns (DummyTransceiver, DummyTransceiver) { + DummyTransceiver e1 = new DummyTransceiver(address(nttManager)); + DummyTransceiver e2 = new DummyTransceiver(address(nttManager)); + nttManager.setTransceiver(address(e1)); + nttManager.enableSendTransceiverForChain(chainId, address(e1)); + nttManager.enableRecvTransceiverForChain(chainId, address(e1)); + nttManager.setTransceiver(address(e2)); + nttManager.enableSendTransceiverForChain(chainId, address(e2)); + nttManager.enableRecvTransceiverForChain(chainId, address(e2)); + nttManager.setThreshold(chainId, 2); + return (e1, e2); + } + + function setup_one_transceiver( + uint16 chainId, + NttManager nttManager + ) internal returns (DummyTransceiver) { + DummyTransceiver e1 = new DummyTransceiver(address(nttManager)); + nttManager.setTransceiver(address(e1)); + nttManager.enableSendTransceiverForChain(chainId, address(e1)); + nttManager.enableRecvTransceiverForChain(chainId, address(e1)); + nttManager.setThreshold(chainId, 1); + return e1; + } + + function setAndEnableTransceiver(NttManager nttMgr, uint16 chnId, address transceiver) public { + nttMgr.setTransceiver(address(transceiver)); + nttMgr.enableSendTransceiverForChain(chnId, address(transceiver)); + nttMgr.enableRecvTransceiverForChain(chnId, address(transceiver)); + } + + function disableAndRemoveTransceiver( + NttManager nttMgr, + uint16 chnId, + address transceiver + ) public { + nttMgr.disableSendTransceiverForChain(chnId, address(transceiver)); + nttMgr.disableRecvTransceiverForChain(chnId, address(transceiver)); + nttMgr.removeTransceiver(address(transceiver)); + } + function attestTransceiversHelper( address to, bytes32 id, diff --git a/evm/test/mocks/MockNttManager.sol b/evm/test/mocks/MockNttManager.sol index 6c3fd09fe..f1a9e7760 100644 --- a/evm/test/mocks/MockNttManager.sol +++ b/evm/test/mocks/MockNttManager.sol @@ -58,7 +58,6 @@ contract MockNttManagerMigrateBasic is NttManager { ) NttManager(token, mode, chainId, rateLimitDuration, skipRateLimiting) {} function _migrate() internal view override { - _checkThresholdInvariants(); _checkTransceiversInvariants(); revert("Proper migrate called"); }