@@ -139,15 +139,19 @@ abstract contract ManagerBase is
139
139
bytes32 nttManagerMessageHash =
140
140
TransceiverStructs.nttManagerMessageDigest (sourceChainId, payload);
141
141
142
+ // The `msg.sender` is the transceiver. Get the index for it.
143
+ uint8 index = _getTransceiverInfosStorage ()[msg .sender ].index;
144
+
145
+ // TODO: Is there a race condition with disabling a transceiver while a tx is outstanding?
146
+ if (! _isRecvTransceiverEnabledForChain (sourceChainId, index)) {
147
+ revert CallerNotTransceiver (msg .sender );
148
+ }
149
+
142
150
// set the attested flag for this transceiver.
143
151
// NOTE: Attestation is idempotent (bitwise or 1), but we revert
144
152
// anyway to ensure that the client does not continue to initiate calls
145
153
// to receive the same message through the same transceiver.
146
- if (
147
- transceiverAttestedToMessage (
148
- nttManagerMessageHash, _getTransceiverInfosStorage ()[msg .sender ].index
149
- )
150
- ) {
154
+ if (transceiverAttestedToMessage (nttManagerMessageHash, index)) {
151
155
revert TransceiverAlreadyAttestedToMessage (nttManagerMessageHash);
152
156
}
153
157
_setTransceiverAttestedToMessage (nttManagerMessageHash, msg .sender );
@@ -162,7 +166,7 @@ abstract contract ManagerBase is
162
166
) internal returns (bytes32 , bool ) {
163
167
bytes32 digest = TransceiverStructs.nttManagerMessageDigest (sourceChainId, message);
164
168
165
- if (! isMessageApproved (digest)) {
169
+ if (! isMessageApproved (sourceChainId, digest)) {
166
170
revert MessageNotApproved (digest);
167
171
}
168
172
@@ -225,7 +229,7 @@ abstract contract ManagerBase is
225
229
)
226
230
{
227
231
// cache enabled transceivers to avoid multiple storage reads
228
- address [] memory enabledTransceivers = _getEnabledTransceiversStorage ( );
232
+ address [] memory enabledTransceivers = getEnabledSendTransceiversForChain (recipientChain );
229
233
230
234
TransceiverStructs.TransceiverInstruction[] memory instructions;
231
235
@@ -280,15 +284,16 @@ abstract contract ManagerBase is
280
284
}
281
285
282
286
/// @inheritdoc IManagerBase
283
- function getThreshold () public view returns (uint8 ) {
284
- return _getThresholdStorage ().num;
287
+ /// @dev This is here because it is defined in IManagerBase.
288
+ function getThreshold (
289
+ uint16 sourceChainId
290
+ ) public view returns (uint8 ) {
291
+ return _getPerChainRecvTransceiverDataStorage ()[sourceChainId].threshold;
285
292
}
286
293
287
294
/// @inheritdoc IManagerBase
288
- function isMessageApproved (
289
- bytes32 digest
290
- ) public view returns (bool ) {
291
- uint8 threshold = getThreshold ();
295
+ function isMessageApproved (uint16 sourceChainId , bytes32 digest ) public view returns (bool ) {
296
+ uint8 threshold = getThreshold (sourceChainId);
292
297
return messageAttestations (digest) >= threshold && threshold > 0 ;
293
298
}
294
299
@@ -397,20 +402,9 @@ abstract contract ManagerBase is
397
402
}
398
403
399
404
/// @inheritdoc IManagerBase
400
- function setThreshold (
401
- uint8 threshold
402
- ) external onlyOwner {
403
- if (threshold == 0 ) {
404
- revert ZeroThreshold ();
405
- }
406
-
407
- _Threshold storage _threshold = _getThresholdStorage ();
408
- uint8 oldThreshold = _threshold.num;
409
-
410
- _threshold.num = threshold;
411
- _checkThresholdInvariants ();
412
-
413
- emit ThresholdChanged (oldThreshold, threshold);
405
+ /// @dev This is here because it is defined in IManagerBase.
406
+ function setThreshold (uint16 sourceChainId , uint8 threshold ) external onlyOwner {
407
+ _setThreshold (sourceChainId, threshold);
414
408
}
415
409
416
410
// =============== Internal ==============================================================
@@ -480,20 +474,4 @@ abstract contract ManagerBase is
480
474
);
481
475
}
482
476
}
483
-
484
- function _checkThresholdInvariants () internal view {
485
- uint8 threshold = _getThresholdStorage ().num;
486
- _NumTransceivers memory numTransceivers = _getNumTransceiversStorage ();
487
-
488
- // invariant: threshold <= enabledTransceivers.length
489
- if (threshold > numTransceivers.enabled) {
490
- revert ThresholdTooHigh (threshold, numTransceivers.enabled);
491
- }
492
-
493
- if (numTransceivers.registered > 0 ) {
494
- if (threshold == 0 ) {
495
- revert ZeroThreshold ();
496
- }
497
- }
498
- }
499
477
}
0 commit comments