@@ -86,6 +86,7 @@ pub struct Batcher {
8686 service_manager : ServiceManager ,
8787 service_manager_fallback : ServiceManager ,
8888 batch_state : Mutex < BatchState > ,
89+ user_mutexes : Mutex < HashMap < Address , Arc < Mutex < ( ) > > > > ,
8990 min_block_interval : u64 ,
9091 transaction_wait_timeout : u64 ,
9192 max_proof_size : usize ,
@@ -276,6 +277,7 @@ impl Batcher {
276277 aggregator_gas_cost : config. batcher . aggregator_gas_cost ,
277278 posting_batch : Mutex :: new ( false ) ,
278279 batch_state : Mutex :: new ( batch_state) ,
280+ user_mutexes : Mutex :: new ( HashMap :: new ( ) ) ,
279281 disabled_verifiers : Mutex :: new ( disabled_verifiers) ,
280282 metrics,
281283 telemetry,
@@ -679,6 +681,7 @@ impl Batcher {
679681 // If it was not present, then the user nonce is queried to the Aligned contract.
680682 // Lastly, we get a lock of the batch state again and insert the user state if it was still missing.
681683
684+ // Step 1: Get or insert the per-address mutex under lock
682685 let is_user_in_state: bool = {
683686 let batch_state_lock = self . batch_state . lock ( ) . await ;
684687 batch_state_lock. user_states . contains_key ( & addr)
@@ -727,13 +730,22 @@ impl Batcher {
727730 // This is needed because we need to query the user state to make validations and
728731 // finally add the proof to the batch queue.
729732
730- let mut batch_state_lock = self . batch_state . lock ( ) . await ;
733+ let user_mutex = {
734+ let mut map = self . user_mutexes . lock ( ) . await ;
735+ map. entry ( addr)
736+ . or_insert_with ( || Arc :: new ( Mutex :: new ( ( ) ) ) )
737+ . clone ( )
738+ } ;
739+ let _ = user_mutex. lock ( ) . await ;
731740
732741 let msg_max_fee = nonced_verification_data. max_fee ;
733- let Some ( user_last_max_fee_limit) =
734- batch_state_lock. get_user_last_max_fee_limit ( & addr) . await
742+ let Some ( user_last_max_fee_limit) = self
743+ . batch_state
744+ . lock ( )
745+ . await
746+ . get_user_last_max_fee_limit ( & addr)
747+ . await
735748 else {
736- std:: mem:: drop ( batch_state_lock) ;
737749 send_message (
738750 ws_conn_sink. clone ( ) ,
739751 SubmitProofResponseMessage :: AddToBatchError ,
@@ -743,9 +755,13 @@ impl Batcher {
743755 return Ok ( ( ) ) ;
744756 } ;
745757
746- let Some ( user_accumulated_fee) = batch_state_lock. get_user_total_fees_in_queue ( & addr) . await
758+ let Some ( user_accumulated_fee) = self
759+ . batch_state
760+ . lock ( )
761+ . await
762+ . get_user_total_fees_in_queue ( & addr)
763+ . await
747764 else {
748- std:: mem:: drop ( batch_state_lock) ;
749765 send_message (
750766 ws_conn_sink. clone ( ) ,
751767 SubmitProofResponseMessage :: AddToBatchError ,
@@ -756,7 +772,6 @@ impl Batcher {
756772 } ;
757773
758774 if !self . verify_user_has_enough_balance ( user_balance, user_accumulated_fee, msg_max_fee) {
759- std:: mem:: drop ( batch_state_lock) ;
760775 send_message (
761776 ws_conn_sink. clone ( ) ,
762777 SubmitProofResponseMessage :: InsufficientBalance ( addr) ,
@@ -766,11 +781,10 @@ impl Batcher {
766781 return Ok ( ( ) ) ;
767782 }
768783
769- let cached_user_nonce = batch_state_lock . get_user_nonce ( & addr) . await ;
784+ let cached_user_nonce = self . batch_state . lock ( ) . await . get_user_nonce ( & addr) . await ;
770785
771786 let Some ( expected_nonce) = cached_user_nonce else {
772787 error ! ( "Failed to get cached user nonce: User not found in user states, but it should have been already inserted" ) ;
773- std:: mem:: drop ( batch_state_lock) ;
774788 send_message (
775789 ws_conn_sink. clone ( ) ,
776790 SubmitProofResponseMessage :: AddToBatchError ,
@@ -781,7 +795,6 @@ impl Batcher {
781795 } ;
782796
783797 if expected_nonce < msg_nonce {
784- std:: mem:: drop ( batch_state_lock) ;
785798 warn ! ( "Invalid nonce for address {addr}, expected nonce: {expected_nonce:?}, received nonce: {msg_nonce:?}" ) ;
786799 send_message (
787800 ws_conn_sink. clone ( ) ,
@@ -797,7 +810,6 @@ impl Batcher {
797810 if expected_nonce > msg_nonce {
798811 info ! ( "Possible replacement message received: Expected nonce {expected_nonce:?} - message nonce: {msg_nonce:?}" ) ;
799812 self . handle_replacement_message (
800- batch_state_lock,
801813 nonced_verification_data,
802814 ws_conn_sink. clone ( ) ,
803815 client_msg. signature ,
@@ -811,7 +823,6 @@ impl Batcher {
811823 // We check this after replacement logic because if user wants to replace a proof, their
812824 // new_max_fee must be greater or equal than old_max_fee
813825 if msg_max_fee > user_last_max_fee_limit {
814- std:: mem:: drop ( batch_state_lock) ;
815826 warn ! ( "Invalid max fee for address {addr}, had fee limit of {user_last_max_fee_limit:?}, sent {msg_max_fee:?}" ) ;
816827 send_message (
817828 ws_conn_sink. clone ( ) ,
@@ -822,13 +833,21 @@ impl Batcher {
822833 return Ok ( ( ) ) ;
823834 }
824835
825- self . verify_proof ( & nonced_verification_data) ;
836+ if let Err ( e) = self . verify_proof ( & nonced_verification_data) . await {
837+ send_message (
838+ ws_conn_sink. clone ( ) ,
839+ SubmitProofResponseMessage :: InvalidProof ( e) ,
840+ )
841+ . await ;
842+ return Ok ( ( ) ) ;
843+ } ;
826844
827845 // * ---------------------------------------------------------------------*
828846 // * Perform validation over batcher queue *
829847 // * ---------------------------------------------------------------------*
830848
831- if batch_state_lock. is_queue_full ( ) {
849+ if self . batch_state . lock ( ) . await . is_queue_full ( ) {
850+ let mut batch_state_lock = self . batch_state . lock ( ) . await ;
832851 debug ! ( "Batch queue is full. Evaluating if the incoming proof can replace a lower-priority entry." ) ;
833852
834853 // This cannot panic, if the batch queue is full it has at least one item
@@ -889,6 +908,7 @@ impl Batcher {
889908 // * Add message data into the queue and update user state *
890909 // * ---------------------------------------------------------------------*
891910
911+ let mut batch_state_lock = self . batch_state . lock ( ) . await ;
892912 if let Err ( e) = self
893913 . add_to_batch (
894914 batch_state_lock,
@@ -934,16 +954,20 @@ impl Batcher {
934954 /// Returns true if the message was replaced in the batch, false otherwise
935955 async fn handle_replacement_message (
936956 & self ,
937- mut batch_state_lock : MutexGuard < ' _ , BatchState > ,
938957 nonced_verification_data : NoncedVerificationData ,
939958 ws_conn_sink : WsMessageSink ,
940959 signature : Signature ,
941960 addr : Address ,
942961 ) {
943962 let replacement_max_fee = nonced_verification_data. max_fee ;
944963 let nonce = nonced_verification_data. nonce ;
945- let Some ( entry) = batch_state_lock. get_entry ( addr, nonce) else {
946- std:: mem:: drop ( batch_state_lock) ;
964+ let Some ( entry) = self
965+ . batch_state
966+ . lock ( )
967+ . await
968+ . get_entry ( addr, nonce)
969+ . cloned ( )
970+ else {
947971 warn ! ( "Invalid nonce for address {addr}. Queue entry with nonce {nonce} not found" ) ;
948972 send_message (
949973 ws_conn_sink. clone ( ) ,
@@ -956,7 +980,6 @@ impl Batcher {
956980
957981 let original_max_fee = entry. nonced_verification_data . max_fee ;
958982 if original_max_fee > replacement_max_fee {
959- std:: mem:: drop ( batch_state_lock) ;
960983 warn ! ( "Invalid replacement message for address {addr}, had max fee: {original_max_fee:?}, received fee: {replacement_max_fee:?}" ) ;
961984 send_message (
962985 ws_conn_sink. clone ( ) ,
@@ -969,7 +992,14 @@ impl Batcher {
969992 }
970993
971994 // if all went well, verify the proof
972- self . verify_proof ( & nonced_verification_data) ;
995+ if let Err ( e) = self . verify_proof ( & nonced_verification_data) . await {
996+ send_message (
997+ ws_conn_sink. clone ( ) ,
998+ SubmitProofResponseMessage :: InvalidProof ( e) ,
999+ )
1000+ . await ;
1001+ return ;
1002+ } ;
9731003
9741004 info ! ( "Replacing message for address {addr} with nonce {nonce} and max fee {replacement_max_fee}" ) ;
9751005
@@ -998,8 +1028,12 @@ impl Batcher {
9981028 }
9991029
10001030 replacement_entry. messaging_sink = Some ( ws_conn_sink. clone ( ) ) ;
1001- if !batch_state_lock. replacement_entry_is_valid ( & replacement_entry) {
1002- std:: mem:: drop ( batch_state_lock) ;
1031+ if !self
1032+ . batch_state
1033+ . lock ( )
1034+ . await
1035+ . replacement_entry_is_valid ( & replacement_entry)
1036+ {
10031037 warn ! ( "Invalid replacement message" ) ;
10041038 send_message (
10051039 ws_conn_sink. clone ( ) ,
@@ -1020,6 +1054,7 @@ impl Batcher {
10201054 // note that the entries are considered equal for the priority queue
10211055 // if they have the same nonce and sender, so we can remove the old entry
10221056 // by calling remove with the new entry
1057+ let mut batch_state_lock = self . batch_state . lock ( ) . await ;
10231058 batch_state_lock. batch_queue . remove ( & replacement_entry) ;
10241059 batch_state_lock. batch_queue . push (
10251060 replacement_entry. clone ( ) ,
0 commit comments