Skip to content

Commit f48157d

Browse files
committed
feat: lock per user implementation
1 parent 766fa70 commit f48157d

File tree

1 file changed

+56
-21
lines changed

1 file changed

+56
-21
lines changed

crates/batcher/src/lib.rs

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ pub struct Batcher {
8686
service_manager: ServiceManager,
8787
service_manager_fallback: ServiceManager,
8888
batch_state: Mutex<BatchState>,
89+
user_mutexes: Mutex<HashMap<Address, Arc<Mutex<()>>>>,
8990
min_block_interval: u64,
9091
transaction_wait_timeout: u64,
9192
max_proof_size: usize,
@@ -276,6 +277,7 @@ impl Batcher {
276277
aggregator_gas_cost: config.batcher.aggregator_gas_cost,
277278
posting_batch: Mutex::new(false),
278279
batch_state: Mutex::new(batch_state),
280+
user_mutexes: Mutex::new(HashMap::new()),
279281
disabled_verifiers: Mutex::new(disabled_verifiers),
280282
metrics,
281283
telemetry,
@@ -679,6 +681,7 @@ impl Batcher {
679681
// If it was not present, then the user nonce is queried to the Aligned contract.
680682
// Lastly, we get a lock of the batch state again and insert the user state if it was still missing.
681683

684+
// Step 1: Get or insert the per-address mutex under lock
682685
let is_user_in_state: bool = {
683686
let batch_state_lock = self.batch_state.lock().await;
684687
batch_state_lock.user_states.contains_key(&addr)
@@ -727,13 +730,22 @@ impl Batcher {
727730
// This is needed because we need to query the user state to make validations and
728731
// finally add the proof to the batch queue.
729732

730-
let mut batch_state_lock = self.batch_state.lock().await;
733+
let user_mutex = {
734+
let mut map = self.user_mutexes.lock().await;
735+
map.entry(addr)
736+
.or_insert_with(|| Arc::new(Mutex::new(())))
737+
.clone()
738+
};
739+
let _ = user_mutex.lock().await;
731740

732741
let msg_max_fee = nonced_verification_data.max_fee;
733-
let Some(user_last_max_fee_limit) =
734-
batch_state_lock.get_user_last_max_fee_limit(&addr).await
742+
let Some(user_last_max_fee_limit) = self
743+
.batch_state
744+
.lock()
745+
.await
746+
.get_user_last_max_fee_limit(&addr)
747+
.await
735748
else {
736-
std::mem::drop(batch_state_lock);
737749
send_message(
738750
ws_conn_sink.clone(),
739751
SubmitProofResponseMessage::AddToBatchError,
@@ -743,9 +755,13 @@ impl Batcher {
743755
return Ok(());
744756
};
745757

746-
let Some(user_accumulated_fee) = batch_state_lock.get_user_total_fees_in_queue(&addr).await
758+
let Some(user_accumulated_fee) = self
759+
.batch_state
760+
.lock()
761+
.await
762+
.get_user_total_fees_in_queue(&addr)
763+
.await
747764
else {
748-
std::mem::drop(batch_state_lock);
749765
send_message(
750766
ws_conn_sink.clone(),
751767
SubmitProofResponseMessage::AddToBatchError,
@@ -756,7 +772,6 @@ impl Batcher {
756772
};
757773

758774
if !self.verify_user_has_enough_balance(user_balance, user_accumulated_fee, msg_max_fee) {
759-
std::mem::drop(batch_state_lock);
760775
send_message(
761776
ws_conn_sink.clone(),
762777
SubmitProofResponseMessage::InsufficientBalance(addr),
@@ -766,11 +781,10 @@ impl Batcher {
766781
return Ok(());
767782
}
768783

769-
let cached_user_nonce = batch_state_lock.get_user_nonce(&addr).await;
784+
let cached_user_nonce = self.batch_state.lock().await.get_user_nonce(&addr).await;
770785

771786
let Some(expected_nonce) = cached_user_nonce else {
772787
error!("Failed to get cached user nonce: User not found in user states, but it should have been already inserted");
773-
std::mem::drop(batch_state_lock);
774788
send_message(
775789
ws_conn_sink.clone(),
776790
SubmitProofResponseMessage::AddToBatchError,
@@ -781,7 +795,6 @@ impl Batcher {
781795
};
782796

783797
if expected_nonce < msg_nonce {
784-
std::mem::drop(batch_state_lock);
785798
warn!("Invalid nonce for address {addr}, expected nonce: {expected_nonce:?}, received nonce: {msg_nonce:?}");
786799
send_message(
787800
ws_conn_sink.clone(),
@@ -797,7 +810,6 @@ impl Batcher {
797810
if expected_nonce > msg_nonce {
798811
info!("Possible replacement message received: Expected nonce {expected_nonce:?} - message nonce: {msg_nonce:?}");
799812
self.handle_replacement_message(
800-
batch_state_lock,
801813
nonced_verification_data,
802814
ws_conn_sink.clone(),
803815
client_msg.signature,
@@ -811,7 +823,6 @@ impl Batcher {
811823
// We check this after replacement logic because if user wants to replace a proof, their
812824
// new_max_fee must be greater or equal than old_max_fee
813825
if msg_max_fee > user_last_max_fee_limit {
814-
std::mem::drop(batch_state_lock);
815826
warn!("Invalid max fee for address {addr}, had fee limit of {user_last_max_fee_limit:?}, sent {msg_max_fee:?}");
816827
send_message(
817828
ws_conn_sink.clone(),
@@ -822,13 +833,21 @@ impl Batcher {
822833
return Ok(());
823834
}
824835

825-
self.verify_proof(&nonced_verification_data);
836+
if let Err(e) = self.verify_proof(&nonced_verification_data).await {
837+
send_message(
838+
ws_conn_sink.clone(),
839+
SubmitProofResponseMessage::InvalidProof(e),
840+
)
841+
.await;
842+
return Ok(());
843+
};
826844

827845
// * ---------------------------------------------------------------------*
828846
// * Perform validation over batcher queue *
829847
// * ---------------------------------------------------------------------*
830848

831-
if batch_state_lock.is_queue_full() {
849+
if self.batch_state.lock().await.is_queue_full() {
850+
let mut batch_state_lock = self.batch_state.lock().await;
832851
debug!("Batch queue is full. Evaluating if the incoming proof can replace a lower-priority entry.");
833852

834853
// This cannot panic, if the batch queue is full it has at least one item
@@ -889,6 +908,7 @@ impl Batcher {
889908
// * Add message data into the queue and update user state *
890909
// * ---------------------------------------------------------------------*
891910

911+
let mut batch_state_lock = self.batch_state.lock().await;
892912
if let Err(e) = self
893913
.add_to_batch(
894914
batch_state_lock,
@@ -934,16 +954,20 @@ impl Batcher {
934954
/// Returns true if the message was replaced in the batch, false otherwise
935955
async fn handle_replacement_message(
936956
&self,
937-
mut batch_state_lock: MutexGuard<'_, BatchState>,
938957
nonced_verification_data: NoncedVerificationData,
939958
ws_conn_sink: WsMessageSink,
940959
signature: Signature,
941960
addr: Address,
942961
) {
943962
let replacement_max_fee = nonced_verification_data.max_fee;
944963
let nonce = nonced_verification_data.nonce;
945-
let Some(entry) = batch_state_lock.get_entry(addr, nonce) else {
946-
std::mem::drop(batch_state_lock);
964+
let Some(entry) = self
965+
.batch_state
966+
.lock()
967+
.await
968+
.get_entry(addr, nonce)
969+
.cloned()
970+
else {
947971
warn!("Invalid nonce for address {addr}. Queue entry with nonce {nonce} not found");
948972
send_message(
949973
ws_conn_sink.clone(),
@@ -956,7 +980,6 @@ impl Batcher {
956980

957981
let original_max_fee = entry.nonced_verification_data.max_fee;
958982
if original_max_fee > replacement_max_fee {
959-
std::mem::drop(batch_state_lock);
960983
warn!("Invalid replacement message for address {addr}, had max fee: {original_max_fee:?}, received fee: {replacement_max_fee:?}");
961984
send_message(
962985
ws_conn_sink.clone(),
@@ -969,7 +992,14 @@ impl Batcher {
969992
}
970993

971994
// if all went well, verify the proof
972-
self.verify_proof(&nonced_verification_data);
995+
if let Err(e) = self.verify_proof(&nonced_verification_data).await {
996+
send_message(
997+
ws_conn_sink.clone(),
998+
SubmitProofResponseMessage::InvalidProof(e),
999+
)
1000+
.await;
1001+
return;
1002+
};
9731003

9741004
info!("Replacing message for address {addr} with nonce {nonce} and max fee {replacement_max_fee}");
9751005

@@ -998,8 +1028,12 @@ impl Batcher {
9981028
}
9991029

10001030
replacement_entry.messaging_sink = Some(ws_conn_sink.clone());
1001-
if !batch_state_lock.replacement_entry_is_valid(&replacement_entry) {
1002-
std::mem::drop(batch_state_lock);
1031+
if !self
1032+
.batch_state
1033+
.lock()
1034+
.await
1035+
.replacement_entry_is_valid(&replacement_entry)
1036+
{
10031037
warn!("Invalid replacement message");
10041038
send_message(
10051039
ws_conn_sink.clone(),
@@ -1020,6 +1054,7 @@ impl Batcher {
10201054
// note that the entries are considered equal for the priority queue
10211055
// if they have the same nonce and sender, so we can remove the old entry
10221056
// by calling remove with the new entry
1057+
let mut batch_state_lock = self.batch_state.lock().await;
10231058
batch_state_lock.batch_queue.remove(&replacement_entry);
10241059
batch_state_lock.batch_queue.push(
10251060
replacement_entry.clone(),

0 commit comments

Comments
 (0)