Skip to content

Commit fdaafd5

Browse files
address_synchronizer: apply @with_lock where applicable
1 parent 4887fb3 commit fdaafd5

File tree

1 file changed

+97
-97
lines changed

1 file changed

+97
-97
lines changed

electrum/address_synchronizer.py

Lines changed: 97 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def is_mine(self, address: Optional[str]) -> bool:
128128
def get_addresses(self):
129129
return sorted(self.db.get_history())
130130

131+
@with_lock
131132
def get_address_history(self, addr: str) -> Dict[str, int]:
132133
"""Returns the history for the address, as a txid->height dict.
133134
In addition to what we have from the server, this includes local and future txns.
@@ -136,11 +137,10 @@ def get_address_history(self, addr: str) -> Dict[str, int]:
136137
so that only includes txns the server sees.
137138
"""
138139
h = {}
139-
with self.lock:
140-
related_txns = self._history_local.get(addr, set())
141-
for tx_hash in related_txns:
142-
tx_height = self.get_tx_height(tx_hash).height
143-
h[tx_hash] = tx_height
140+
related_txns = self._history_local.get(addr, set())
141+
for tx_hash in related_txns:
142+
tx_height = self.get_tx_height(tx_hash).height
143+
h[tx_hash] = tx_height
144144
return h
145145

146146
def get_address_history_len(self, addr: str) -> int:
@@ -201,10 +201,10 @@ def start_network(self, network: Optional['Network']) -> None:
201201
self.register_callbacks()
202202

203203
@event_listener
204+
@with_lock
204205
def on_event_blockchain_updated(self, *args):
205-
with self.lock:
206-
self._get_balance_cache = {} # invalidate cache
207-
self.db.put('stored_height', self.get_local_height())
206+
self._get_balance_cache = {} # invalidate cache
207+
self.db.put('stored_height', self.get_local_height())
208208

209209
async def stop(self):
210210
if self.network:
@@ -227,6 +227,7 @@ def add_address(self, address):
227227
self.synchronizer.add(address)
228228
self.up_to_date_changed()
229229

230+
@with_lock
230231
def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool = False) -> Set[str]:
231232
"""Returns a set of transaction hashes from the wallet history that are
232233
directly conflicting with tx, i.e. they have common outpoints being
@@ -236,27 +237,26 @@ def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool =
236237
conflict (if already in wallet history)
237238
"""
238239
conflicting_txns = set()
239-
with self.lock:
240-
for txin in tx.inputs():
241-
if txin.is_coinbase_input():
242-
continue
243-
prevout_hash = txin.prevout.txid.hex()
244-
prevout_n = txin.prevout.out_idx
245-
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
246-
if spending_tx_hash is None:
247-
continue
248-
# this outpoint has already been spent, by spending_tx
249-
# annoying assert that has revealed several bugs over time:
250-
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
251-
conflicting_txns |= {spending_tx_hash}
252-
if tx_hash := tx.txid():
253-
if tx_hash in conflicting_txns:
254-
# this tx is already in history, so it conflicts with itself
255-
if len(conflicting_txns) > 1:
256-
raise Exception('Found conflicting transactions already in wallet history.')
257-
if not include_self:
258-
conflicting_txns -= {tx_hash}
259-
return conflicting_txns
240+
for txin in tx.inputs():
241+
if txin.is_coinbase_input():
242+
continue
243+
prevout_hash = txin.prevout.txid.hex()
244+
prevout_n = txin.prevout.out_idx
245+
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
246+
if spending_tx_hash is None:
247+
continue
248+
# this outpoint has already been spent, by spending_tx
249+
# annoying assert that has revealed several bugs over time:
250+
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
251+
conflicting_txns |= {spending_tx_hash}
252+
if tx_hash := tx.txid():
253+
if tx_hash in conflicting_txns:
254+
# this tx is already in history, so it conflicts with itself
255+
if len(conflicting_txns) > 1:
256+
raise Exception('Found conflicting transactions already in wallet history.')
257+
if not include_self:
258+
conflicting_txns -= {tx_hash}
259+
return conflicting_txns
260260

261261
@with_lock
262262
def get_transaction(self, txid: str) -> Optional[Transaction]:
@@ -368,15 +368,15 @@ def add_value_from_prev_output():
368368
util.trigger_callback('adb_added_tx', self, tx_hash, tx)
369369
return True
370370

371+
@with_lock
371372
def remove_transaction(self, tx_hash: str) -> None:
372373
"""Removes a transaction AND all its dependents/children
373374
from the wallet history.
374375
"""
375-
with self.lock:
376-
to_remove = {tx_hash}
377-
to_remove |= self.get_depending_transactions(tx_hash)
378-
for txid in to_remove:
379-
self._remove_transaction(txid)
376+
to_remove = {tx_hash}
377+
to_remove |= self.get_depending_transactions(tx_hash)
378+
for txid in to_remove:
379+
self._remove_transaction(txid)
380380

381381
def _remove_transaction(self, tx_hash: str) -> None:
382382
"""Removes a single transaction from the wallet history, and attempts
@@ -419,15 +419,15 @@ def remove_from_spent_outpoints():
419419
self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value)
420420
util.trigger_callback('adb_removed_tx', self, tx_hash, tx)
421421

422+
@with_lock
422423
def get_depending_transactions(self, tx_hash: str) -> Set[str]:
423424
"""Returns all (grand-)children of tx_hash in this wallet."""
424-
with self.lock:
425-
children = set()
426-
for n in self.db.get_spent_outpoints(tx_hash):
427-
other_hash = self.db.get_spent_outpoint(tx_hash, n)
428-
children.add(other_hash)
429-
children |= self.get_depending_transactions(other_hash)
430-
return children
425+
children = set()
426+
for n in self.db.get_spent_outpoints(tx_hash):
427+
other_hash = self.db.get_spent_outpoint(tx_hash, n)
428+
children.add(other_hash)
429+
children |= self.get_depending_transactions(other_hash)
430+
return children
431431

432432
@with_lock
433433
def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None:
@@ -499,19 +499,19 @@ def remove_local_transactions_we_dont_have(self):
499499
if tx_height == TX_HEIGHT_LOCAL and not self.db.get_transaction(txid):
500500
self.remove_transaction(txid)
501501

502+
@with_lock
502503
def clear_history(self):
503-
with self.lock:
504-
self.db.clear_history()
505-
self._history_local.clear()
506-
self._get_balance_cache.clear() # invalidate cache
504+
self.db.clear_history()
505+
self._history_local.clear()
506+
self._get_balance_cache.clear() # invalidate cache
507507

508+
@with_lock
508509
def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]:
509510
"""Returns a key to be used for sorting txs."""
510-
with self.lock:
511-
tx_mined_info = self.get_tx_height(tx_hash)
512-
height = self.tx_height_to_sort_height(tx_mined_info.height)
513-
txpos = tx_mined_info.txpos or -1
514-
return height, txpos
511+
tx_mined_info = self.get_tx_height(tx_hash)
512+
height = self.tx_height_to_sort_height(tx_mined_info.height)
513+
txpos = tx_mined_info.txpos or -1
514+
return height, txpos
515515

516516
@classmethod
517517
def tx_height_to_sort_height(cls, height: int = None):
@@ -578,25 +578,25 @@ def get_history(self, domain) -> Sequence[HistoryItem]:
578578
raise Exception("wallet.get_history() failed balance sanity-check")
579579
return h2
580580

581+
@with_lock
581582
def _add_tx_to_local_history(self, txid):
582-
with self.lock:
583-
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
584-
cur_hist = self._history_local.get(addr, set())
585-
cur_hist.add(txid)
586-
self._history_local[addr] = cur_hist
587-
self._mark_address_history_changed(addr)
583+
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
584+
cur_hist = self._history_local.get(addr, set())
585+
cur_hist.add(txid)
586+
self._history_local[addr] = cur_hist
587+
self._mark_address_history_changed(addr)
588588

589+
@with_lock
589590
def _remove_tx_from_local_history(self, txid):
590-
with self.lock:
591-
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
592-
cur_hist = self._history_local.get(addr, set())
593-
try:
594-
cur_hist.remove(txid)
595-
except KeyError:
596-
pass
597-
else:
598-
self._history_local[addr] = cur_hist
599-
self._mark_address_history_changed(addr)
591+
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
592+
cur_hist = self._history_local.get(addr, set())
593+
try:
594+
cur_hist.remove(txid)
595+
except KeyError:
596+
pass
597+
else:
598+
self._history_local[addr] = cur_hist
599+
self._mark_address_history_changed(addr)
600600

601601
def _mark_address_history_changed(self, addr: str) -> None:
602602
def set_and_clear():
@@ -617,27 +617,27 @@ async def wait_for_address_history_to_change(self, addr: str) -> None:
617617
assert self.is_mine(addr), "address needs to be is_mine to be watched"
618618
await self._address_history_changed_events[addr].wait()
619619

620+
@with_lock
620621
def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None:
621622
assert tx_height >= TX_HEIGHT_UNCONF_PARENT, f"got {tx_height=} for {tx_hash=}" # forbid local/future txs here
622-
with self.lock:
623-
if self.db.is_in_verified_tx(tx_hash):
624-
if tx_height <= 0:
625-
# tx was previously SPV-verified but now in mempool (probably reorg)
626-
self.db.remove_verified_tx(tx_hash)
627-
self.unconfirmed_tx[tx_hash] = tx_height
628-
if self.verifier:
629-
self.verifier.remove_spv_proof_for_tx(tx_hash)
623+
if self.db.is_in_verified_tx(tx_hash):
624+
if tx_height <= 0:
625+
# tx was previously SPV-verified but now in mempool (probably reorg)
626+
self.db.remove_verified_tx(tx_hash)
627+
self.unconfirmed_tx[tx_hash] = tx_height
628+
if self.verifier:
629+
self.verifier.remove_spv_proof_for_tx(tx_hash)
630+
else:
631+
if tx_height > 0:
632+
self.unverified_tx[tx_hash] = tx_height
630633
else:
631-
if tx_height > 0:
632-
self.unverified_tx[tx_hash] = tx_height
633-
else:
634-
self.unconfirmed_tx[tx_hash] = tx_height
634+
self.unconfirmed_tx[tx_hash] = tx_height
635635

636+
@with_lock
636637
def remove_unverified_tx(self, tx_hash: str, tx_height: int) -> None:
637-
with self.lock:
638-
new_height = self.unverified_tx.get(tx_hash)
639-
if new_height == tx_height:
640-
self.unverified_tx.pop(tx_hash, None)
638+
new_height = self.unverified_tx.get(tx_hash)
639+
if new_height == tx_height:
640+
self.unverified_tx.pop(tx_hash, None)
641641

642642
def add_verified_tx(self, tx_hash: str, info: TxMinedInfo):
643643
# Remove from the unverified map and add to the verified map
@@ -646,10 +646,10 @@ def add_verified_tx(self, tx_hash: str, info: TxMinedInfo):
646646
self.db.add_verified_tx(tx_hash, info)
647647
util.trigger_callback('adb_added_verified_tx', self, tx_hash)
648648

649+
@with_lock
649650
def get_unverified_txs(self) -> Dict[str, int]:
650651
'''Returns a map from tx hash to transaction height'''
651-
with self.lock:
652-
return dict(self.unverified_tx) # copy
652+
return dict(self.unverified_tx) # copy
653653

654654
def undo_verifications(self, blockchain: Blockchain, above_height: int) -> Set[str]:
655655
'''Used by the verifier when a reorg has happened'''
@@ -830,20 +830,20 @@ def get_tx_fee(self, txid: str) -> Optional[int]:
830830
self.db.add_num_inputs_to_tx(txid, len(tx.inputs()))
831831
return fee
832832

833+
@with_lock
833834
def get_addr_io(self, address: str):
834-
with self.lock:
835-
h = self.get_address_history(address).items()
836-
received = {}
837-
sent = {}
838-
for tx_hash, height in h:
839-
tx_mined_info = self.get_tx_height(tx_hash)
840-
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
841-
d = self.db.get_txo_addr(tx_hash, address)
842-
for n, (v, is_cb) in d.items():
843-
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
844-
l = self.db.get_txi_addr(tx_hash, address)
845-
for txi, v in l:
846-
sent[txi] = tx_hash, height, txpos
835+
h = self.get_address_history(address).items()
836+
received = {}
837+
sent = {}
838+
for tx_hash, height in h:
839+
tx_mined_info = self.get_tx_height(tx_hash)
840+
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
841+
d = self.db.get_txo_addr(tx_hash, address)
842+
for n, (v, is_cb) in d.items():
843+
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
844+
l = self.db.get_txi_addr(tx_hash, address)
845+
for txi, v in l:
846+
sent[txi] = tx_hash, height, txpos
847847
return received, sent
848848

849849
def get_addr_outputs(self, address: str) -> Dict[TxOutpoint, PartialTxInput]:

0 commit comments

Comments
 (0)