Skip to content

Commit 56d5025

Browse files
authored
Merge pull request #9958 from nabijaczleweli/cachet
address_synchronizer: add a cache in front of get_utxos()
2 parents 4887fb3 + e28836e commit 56d5025

File tree

1 file changed

+115
-100
lines changed

1 file changed

+115
-100
lines changed

electrum/address_synchronizer.py

Lines changed: 115 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# SOFTWARE.
2323

2424
import asyncio
25+
import copy
2526
import threading
2627
import itertools
2728
from collections import defaultdict
@@ -99,9 +100,15 @@ def __init__(self, db: 'WalletDB', config: 'SimpleConfig', *, name: str = None):
99100
self.threadlocal_cache = threading.local()
100101

101102
self._get_balance_cache = {}
103+
self._get_utxos_cache = {}
102104

103105
self.load_and_cleanup()
104106

107+
@with_lock
108+
def invalidate_cache(self):
109+
self._get_balance_cache.clear()
110+
self._get_utxos_cache.clear()
111+
105112
def diagnostic_name(self):
106113
return self.name or ""
107114

@@ -128,6 +135,7 @@ def is_mine(self, address: Optional[str]) -> bool:
128135
def get_addresses(self):
129136
return sorted(self.db.get_history())
130137

138+
@with_lock
131139
def get_address_history(self, addr: str) -> Dict[str, int]:
132140
"""Returns the history for the address, as a txid->height dict.
133141
In addition to what we have from the server, this includes local and future txns.
@@ -136,11 +144,10 @@ def get_address_history(self, addr: str) -> Dict[str, int]:
136144
so that only includes txns the server sees.
137145
"""
138146
h = {}
139-
with self.lock:
140-
related_txns = self._history_local.get(addr, set())
141-
for tx_hash in related_txns:
142-
tx_height = self.get_tx_height(tx_hash).height
143-
h[tx_hash] = tx_height
147+
related_txns = self._history_local.get(addr, set())
148+
for tx_hash in related_txns:
149+
tx_height = self.get_tx_height(tx_hash).height
150+
h[tx_hash] = tx_height
144151
return h
145152

146153
def get_address_history_len(self, addr: str) -> int:
@@ -201,10 +208,10 @@ def start_network(self, network: Optional['Network']) -> None:
201208
self.register_callbacks()
202209

203210
@event_listener
211+
@with_lock
204212
def on_event_blockchain_updated(self, *args):
205-
with self.lock:
206-
self._get_balance_cache = {} # invalidate cache
207-
self.db.put('stored_height', self.get_local_height())
213+
self.invalidate_cache()
214+
self.db.put('stored_height', self.get_local_height())
208215

209216
async def stop(self):
210217
if self.network:
@@ -227,6 +234,7 @@ def add_address(self, address):
227234
self.synchronizer.add(address)
228235
self.up_to_date_changed()
229236

237+
@with_lock
230238
def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool = False) -> Set[str]:
231239
"""Returns a set of transaction hashes from the wallet history that are
232240
directly conflicting with tx, i.e. they have common outpoints being
@@ -236,27 +244,26 @@ def get_conflicting_transactions(self, tx: Transaction, *, include_self: bool =
236244
conflict (if already in wallet history)
237245
"""
238246
conflicting_txns = set()
239-
with self.lock:
240-
for txin in tx.inputs():
241-
if txin.is_coinbase_input():
242-
continue
243-
prevout_hash = txin.prevout.txid.hex()
244-
prevout_n = txin.prevout.out_idx
245-
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
246-
if spending_tx_hash is None:
247-
continue
248-
# this outpoint has already been spent, by spending_tx
249-
# annoying assert that has revealed several bugs over time:
250-
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
251-
conflicting_txns |= {spending_tx_hash}
252-
if tx_hash := tx.txid():
253-
if tx_hash in conflicting_txns:
254-
# this tx is already in history, so it conflicts with itself
255-
if len(conflicting_txns) > 1:
256-
raise Exception('Found conflicting transactions already in wallet history.')
257-
if not include_self:
258-
conflicting_txns -= {tx_hash}
259-
return conflicting_txns
247+
for txin in tx.inputs():
248+
if txin.is_coinbase_input():
249+
continue
250+
prevout_hash = txin.prevout.txid.hex()
251+
prevout_n = txin.prevout.out_idx
252+
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
253+
if spending_tx_hash is None:
254+
continue
255+
# this outpoint has already been spent, by spending_tx
256+
# annoying assert that has revealed several bugs over time:
257+
assert self.db.get_transaction(spending_tx_hash), "spending tx not in wallet db"
258+
conflicting_txns |= {spending_tx_hash}
259+
if tx_hash := tx.txid():
260+
if tx_hash in conflicting_txns:
261+
# this tx is already in history, so it conflicts with itself
262+
if len(conflicting_txns) > 1:
263+
raise Exception('Found conflicting transactions already in wallet history.')
264+
if not include_self:
265+
conflicting_txns -= {tx_hash}
266+
return conflicting_txns
260267

261268
@with_lock
262269
def get_transaction(self, txid: str) -> Optional[Transaction]:
@@ -335,7 +342,7 @@ def add_value_from_prev_output():
335342
pass
336343
else:
337344
self.db.add_txi_addr(tx_hash, addr, ser, v)
338-
self._get_balance_cache.clear() # invalidate cache
345+
self.invalidate_cache()
339346
for txi in tx.inputs():
340347
if txi.is_coinbase_input():
341348
continue
@@ -353,7 +360,7 @@ def add_value_from_prev_output():
353360
addr = txo.address
354361
if addr and self.is_mine(addr):
355362
self.db.add_txo_addr(tx_hash, addr, n, v, is_coinbase)
356-
self._get_balance_cache.clear() # invalidate cache
363+
self.invalidate_cache()
357364
# give v to txi that spends me
358365
next_tx = self.db.get_spent_outpoint(tx_hash, n)
359366
if next_tx is not None:
@@ -368,15 +375,15 @@ def add_value_from_prev_output():
368375
util.trigger_callback('adb_added_tx', self, tx_hash, tx)
369376
return True
370377

378+
@with_lock
371379
def remove_transaction(self, tx_hash: str) -> None:
372380
"""Removes a transaction AND all its dependents/children
373381
from the wallet history.
374382
"""
375-
with self.lock:
376-
to_remove = {tx_hash}
377-
to_remove |= self.get_depending_transactions(tx_hash)
378-
for txid in to_remove:
379-
self._remove_transaction(txid)
383+
to_remove = {tx_hash}
384+
to_remove |= self.get_depending_transactions(tx_hash)
385+
for txid in to_remove:
386+
self._remove_transaction(txid)
380387

381388
def _remove_transaction(self, tx_hash: str) -> None:
382389
"""Removes a single transaction from the wallet history, and attempts
@@ -405,7 +412,7 @@ def remove_from_spent_outpoints():
405412
remove_from_spent_outpoints()
406413
self._remove_tx_from_local_history(tx_hash)
407414
for addr in itertools.chain(self.db.get_txi_addresses(tx_hash), self.db.get_txo_addresses(tx_hash)):
408-
self._get_balance_cache.clear() # invalidate cache
415+
self.invalidate_cache()
409416
self.db.remove_txi(tx_hash)
410417
self.db.remove_txo(tx_hash)
411418
self.db.remove_tx_fee(tx_hash)
@@ -419,15 +426,15 @@ def remove_from_spent_outpoints():
419426
self.db.remove_prevout_by_scripthash(scripthash, prevout=prevout, value=txo.value)
420427
util.trigger_callback('adb_removed_tx', self, tx_hash, tx)
421428

429+
@with_lock
422430
def get_depending_transactions(self, tx_hash: str) -> Set[str]:
423431
"""Returns all (grand-)children of tx_hash in this wallet."""
424-
with self.lock:
425-
children = set()
426-
for n in self.db.get_spent_outpoints(tx_hash):
427-
other_hash = self.db.get_spent_outpoint(tx_hash, n)
428-
children.add(other_hash)
429-
children |= self.get_depending_transactions(other_hash)
430-
return children
432+
children = set()
433+
for n in self.db.get_spent_outpoints(tx_hash):
434+
other_hash = self.db.get_spent_outpoint(tx_hash, n)
435+
children.add(other_hash)
436+
children |= self.get_depending_transactions(other_hash)
437+
return children
431438

432439
@with_lock
433440
def receive_tx_callback(self, tx: Transaction, *, tx_height: Optional[int] = None) -> None:
@@ -499,19 +506,19 @@ def remove_local_transactions_we_dont_have(self):
499506
if tx_height == TX_HEIGHT_LOCAL and not self.db.get_transaction(txid):
500507
self.remove_transaction(txid)
501508

509+
@with_lock
502510
def clear_history(self):
503-
with self.lock:
504-
self.db.clear_history()
505-
self._history_local.clear()
506-
self._get_balance_cache.clear() # invalidate cache
511+
self.db.clear_history()
512+
self._history_local.clear()
513+
self.invalidate_cache()
507514

515+
@with_lock
508516
def _get_tx_sort_key(self, tx_hash: str) -> Tuple[int, int]:
509517
"""Returns a key to be used for sorting txs."""
510-
with self.lock:
511-
tx_mined_info = self.get_tx_height(tx_hash)
512-
height = self.tx_height_to_sort_height(tx_mined_info.height)
513-
txpos = tx_mined_info.txpos or -1
514-
return height, txpos
518+
tx_mined_info = self.get_tx_height(tx_hash)
519+
height = self.tx_height_to_sort_height(tx_mined_info.height)
520+
txpos = tx_mined_info.txpos or -1
521+
return height, txpos
515522

516523
@classmethod
517524
def tx_height_to_sort_height(cls, height: int = None):
@@ -578,25 +585,25 @@ def get_history(self, domain) -> Sequence[HistoryItem]:
578585
raise Exception("wallet.get_history() failed balance sanity-check")
579586
return h2
580587

588+
@with_lock
581589
def _add_tx_to_local_history(self, txid):
582-
with self.lock:
583-
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
584-
cur_hist = self._history_local.get(addr, set())
585-
cur_hist.add(txid)
586-
self._history_local[addr] = cur_hist
587-
self._mark_address_history_changed(addr)
590+
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
591+
cur_hist = self._history_local.get(addr, set())
592+
cur_hist.add(txid)
593+
self._history_local[addr] = cur_hist
594+
self._mark_address_history_changed(addr)
588595

596+
@with_lock
589597
def _remove_tx_from_local_history(self, txid):
590-
with self.lock:
591-
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
592-
cur_hist = self._history_local.get(addr, set())
593-
try:
594-
cur_hist.remove(txid)
595-
except KeyError:
596-
pass
597-
else:
598-
self._history_local[addr] = cur_hist
599-
self._mark_address_history_changed(addr)
598+
for addr in itertools.chain(self.db.get_txi_addresses(txid), self.db.get_txo_addresses(txid)):
599+
cur_hist = self._history_local.get(addr, set())
600+
try:
601+
cur_hist.remove(txid)
602+
except KeyError:
603+
pass
604+
else:
605+
self._history_local[addr] = cur_hist
606+
self._mark_address_history_changed(addr)
600607

601608
def _mark_address_history_changed(self, addr: str) -> None:
602609
def set_and_clear():
@@ -617,27 +624,27 @@ async def wait_for_address_history_to_change(self, addr: str) -> None:
617624
assert self.is_mine(addr), "address needs to be is_mine to be watched"
618625
await self._address_history_changed_events[addr].wait()
619626

627+
@with_lock
620628
def add_unverified_or_unconfirmed_tx(self, tx_hash: str, tx_height: int) -> None:
621629
assert tx_height >= TX_HEIGHT_UNCONF_PARENT, f"got {tx_height=} for {tx_hash=}" # forbid local/future txs here
622-
with self.lock:
623-
if self.db.is_in_verified_tx(tx_hash):
624-
if tx_height <= 0:
625-
# tx was previously SPV-verified but now in mempool (probably reorg)
626-
self.db.remove_verified_tx(tx_hash)
627-
self.unconfirmed_tx[tx_hash] = tx_height
628-
if self.verifier:
629-
self.verifier.remove_spv_proof_for_tx(tx_hash)
630+
if self.db.is_in_verified_tx(tx_hash):
631+
if tx_height <= 0:
632+
# tx was previously SPV-verified but now in mempool (probably reorg)
633+
self.db.remove_verified_tx(tx_hash)
634+
self.unconfirmed_tx[tx_hash] = tx_height
635+
if self.verifier:
636+
self.verifier.remove_spv_proof_for_tx(tx_hash)
637+
else:
638+
if tx_height > 0:
639+
self.unverified_tx[tx_hash] = tx_height
630640
else:
631-
if tx_height > 0:
632-
self.unverified_tx[tx_hash] = tx_height
633-
else:
634-
self.unconfirmed_tx[tx_hash] = tx_height
641+
self.unconfirmed_tx[tx_hash] = tx_height
635642

643+
@with_lock
636644
def remove_unverified_tx(self, tx_hash: str, tx_height: int) -> None:
637-
with self.lock:
638-
new_height = self.unverified_tx.get(tx_hash)
639-
if new_height == tx_height:
640-
self.unverified_tx.pop(tx_hash, None)
645+
new_height = self.unverified_tx.get(tx_hash)
646+
if new_height == tx_height:
647+
self.unverified_tx.pop(tx_hash, None)
641648

642649
def add_verified_tx(self, tx_hash: str, info: TxMinedInfo):
643650
# Remove from the unverified map and add to the verified map
@@ -646,10 +653,10 @@ def add_verified_tx(self, tx_hash: str, info: TxMinedInfo):
646653
self.db.add_verified_tx(tx_hash, info)
647654
util.trigger_callback('adb_added_verified_tx', self, tx_hash)
648655

656+
@with_lock
649657
def get_unverified_txs(self) -> Dict[str, int]:
650658
'''Returns a map from tx hash to transaction height'''
651-
with self.lock:
652-
return dict(self.unverified_tx) # copy
659+
return dict(self.unverified_tx) # copy
653660

654661
def undo_verifications(self, blockchain: Blockchain, above_height: int) -> Set[str]:
655662
'''Used by the verifier when a reorg has happened'''
@@ -830,20 +837,20 @@ def get_tx_fee(self, txid: str) -> Optional[int]:
830837
self.db.add_num_inputs_to_tx(txid, len(tx.inputs()))
831838
return fee
832839

840+
@with_lock
833841
def get_addr_io(self, address: str):
834-
with self.lock:
835-
h = self.get_address_history(address).items()
836-
received = {}
837-
sent = {}
838-
for tx_hash, height in h:
839-
tx_mined_info = self.get_tx_height(tx_hash)
840-
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
841-
d = self.db.get_txo_addr(tx_hash, address)
842-
for n, (v, is_cb) in d.items():
843-
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
844-
l = self.db.get_txi_addr(tx_hash, address)
845-
for txi, v in l:
846-
sent[txi] = tx_hash, height, txpos
842+
h = self.get_address_history(address).items()
843+
received = {}
844+
sent = {}
845+
for tx_hash, height in h:
846+
tx_mined_info = self.get_tx_height(tx_hash)
847+
txpos = tx_mined_info.txpos if tx_mined_info.txpos is not None else -1
848+
d = self.db.get_txo_addr(tx_hash, address)
849+
for n, (v, is_cb) in d.items():
850+
received[tx_hash + ':%d'%n] = (height, txpos, v, is_cb)
851+
l = self.db.get_txi_addr(tx_hash, address)
852+
for txi, v in l:
853+
sent[txi] = tx_hash, height, txpos
847854
return received, sent
848855

849856
def get_addr_outputs(self, address: str) -> Dict[TxOutpoint, PartialTxInput]:
@@ -970,6 +977,13 @@ def get_utxos(
970977
if excluded_addresses:
971978
domain = set(domain) - set(excluded_addresses)
972979
mempool_height = block_height + 1 # height of next block
980+
cache_key = sha256(
981+
','.join(sorted(domain))
982+
+ f";{mature_only};{confirmed_funding_only};{confirmed_spending_only};{nonlocal_only};{block_height}"
983+
)
984+
cached = self._get_utxos_cache.get(cache_key)
985+
if cached is not None:
986+
return copy.deepcopy(cached)
973987
for addr in domain:
974988
txos = self.get_addr_outputs(addr)
975989
for txo in txos.values():
@@ -987,6 +1001,7 @@ def get_utxos(
9871001
continue
9881002
coins.append(txo)
9891003
continue
1004+
self._get_utxos_cache[cache_key] = copy.deepcopy(coins)
9901005
return coins
9911006

9921007
def is_used(self, address: str) -> bool:

0 commit comments

Comments
 (0)