Skip to content

Commit 0a414ad

Browse files
authored
Merge pull request #10290 from SomberNight/202510_synchronizer_guess_status1
synchronizer: rm redundant get_history call if new block mined unconf
2 parents f3420aa + 6d016d7 commit 0a414ad

File tree

2 files changed

+177
-19
lines changed

2 files changed

+177
-19
lines changed

electrum/synchronizer.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
# SOFTWARE.
2525
import asyncio
2626
import hashlib
27-
from typing import Dict, List, TYPE_CHECKING, Tuple, Set
27+
from typing import Dict, List, TYPE_CHECKING, Tuple, Set, Optional, Sequence
2828
from collections import defaultdict
2929
import logging
3030

@@ -45,7 +45,7 @@
4545
class SynchronizerFailure(Exception): pass
4646

4747

48-
def history_status(h):
48+
def history_status(h: Sequence[tuple[str, int]]) -> Optional[str]:
4949
if not h:
5050
return None
5151
status = ''
@@ -96,7 +96,7 @@ async def _add_address(self, addr: str):
9696
finally:
9797
self._adding_addrs.discard(addr) # ok for addr not to be present
9898

99-
async def _on_address_status(self, addr, status):
99+
async def _on_address_status(self, addr: str, status: Optional[str]):
100100
"""Handle the change of the status of an address.
101101
Should remove addr from self._handling_addr_statuses when done.
102102
"""
@@ -160,10 +160,30 @@ def is_up_to_date(self):
160160
and not self._stale_histories
161161
and self.status_queue.empty())
162162

163+
async def _maybe_request_history_for_addr(self, addr: str, *, ann_status: Optional[str]) -> List[dict]:
164+
# First opportunistically try to guess the addr history. Might save us network requests.
165+
old_history = self.adb.db.get_addr_history(addr)
166+
def guess_height(old_height: int) -> int:
167+
if old_height in (0, -1,):
168+
return self.interface.tip # maybe mempool tx got mined just now
169+
return old_height
170+
guessed_history = [(txid, guess_height(old_height)) for (txid, old_height) in old_history]
171+
if history_status(guessed_history) == ann_status:
172+
self.logger.debug(f"managed to guess new history for {addr}. won't call 'blockchain.scripthash.get_history'.")
173+
return [{"height": height, "tx_hash": txid} for (txid, height) in guessed_history]
174+
# request addr history from server
175+
sh = address_to_scripthash(addr)
176+
self._requests_sent += 1
177+
async with self._network_request_semaphore:
178+
result = await self.interface.get_history_for_scripthash(sh)
179+
self._requests_answered += 1
180+
self.logger.info(f"receiving history {addr} {len(result)}")
181+
return result
182+
163183
async def _on_address_status(self, addr, status):
164184
try:
165-
history = self.adb.db.get_addr_history(addr)
166-
if history_status(history) == status:
185+
old_history = self.adb.db.get_addr_history(addr)
186+
if history_status(old_history) == status:
167187
return
168188
# No point in requesting history twice for the same announced status.
169189
# However if we got announced a new status, we should request history again:
@@ -174,12 +194,7 @@ async def _on_address_status(self, addr, status):
174194
self._stale_histories.pop(addr, asyncio.Future()).cancel()
175195
finally:
176196
self._handling_addr_statuses.discard(addr)
177-
h = address_to_scripthash(addr)
178-
self._requests_sent += 1
179-
async with self._network_request_semaphore:
180-
result = await self.interface.get_history_for_scripthash(h)
181-
self._requests_answered += 1
182-
self.logger.info(f"receiving history {addr} {len(result)}")
197+
result = await self._maybe_request_history_for_addr(addr, ann_status=status)
183198
hist = list(map(lambda item: (item['tx_hash'], item['height']), result))
184199
# tx_fees
185200
tx_fees = [(item['tx_hash'], item.get('fee')) for item in result]
@@ -242,7 +257,7 @@ async def _get_transaction(self, tx_hash, *, allow_server_not_finding_tx=False):
242257
raise SynchronizerFailure(f"received tx does not match expected txid ({tx_hash} != {tx.txid()})")
243258
self.requested_tx.remove(tx_hash)
244259
self.adb.receive_tx_callback(tx)
245-
self.logger.info(f"received tx {tx_hash}. bytes: {len(raw_tx)}")
260+
self.logger.info(f"received tx {tx_hash}. bytes-len: {len(raw_tx)//2}")
246261

247262
async def main(self):
248263
self.adb.up_to_date_changed()

tests/test_interface.py

Lines changed: 150 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import collections
3+
from typing import Optional, Sequence, Iterable
34

45
import aiorpcx
56
from aiorpcx import RPCError
@@ -12,8 +13,13 @@
1213
from electrum.simple_config import SimpleConfig
1314
from electrum.transaction import Transaction
1415
from electrum import constants
16+
from electrum.wallet import Abstract_Wallet
17+
from electrum.blockchain import Blockchain
18+
from electrum.bitcoin import script_to_scripthash
19+
from electrum.synchronizer import history_status
1520

1621
from . import ElectrumTestCase
22+
from . import restore_wallet_from_text__for_unittest
1723

1824

1925
class TestServerAddr(ElectrumTestCase):
@@ -86,6 +92,10 @@ async def switch_unwanted_fork_interface(self):
8692
pass
8793
async def switch_lagging_interface(self):
8894
pass
95+
def blockchain(self) -> Blockchain:
96+
return self.interface.blockchain
97+
def get_local_height(self) -> int:
98+
return self.blockchain().height()
8999

90100

91101
# regtest chain:
@@ -106,11 +116,11 @@ async def switch_lagging_interface(self):
106116
}
107117

108118
_active_server_sessions = set()
109-
def _get_active_server_session() -> 'ServerSession':
119+
def _get_active_server_session() -> 'ToyServerSession':
110120
assert 1 == len(_active_server_sessions), len(_active_server_sessions)
111121
return list(_active_server_sessions)[0]
112122

113-
class ServerSession(aiorpcx.RPCSession, Logger):
123+
class ToyServerSession(aiorpcx.RPCSession, Logger):
114124

115125
def __init__(self, *args, **kwargs):
116126
aiorpcx.RPCSession.__init__(self, *args, **kwargs)
@@ -120,6 +130,12 @@ def __init__(self, *args, **kwargs):
120130
self.txs = {
121131
"bdae818ad3c1f261317738ae9284159bf54874356f186dbc7afd631dc1527fcb": bfh("020000000001010000000000000000000000000000000000000000000000000000000000000000ffffffff025100ffffffff0200f2052a010000001600140297bde2689a3c79ffe050583b62f86f2d9dae540000000000000000266a24aa21a9ede2f61c3f71d1defd3fa999dfa36953755c690689799962b48bebd836974e8cf90120000000000000000000000000000000000000000000000000000000000000000000000000"),
122132
} # type: dict[str, bytes]
133+
self.txid_to_block_height = collections.defaultdict(int) # type: dict[str, int]
134+
self.subbed_headers = False
135+
self.notified_height = None # type: Optional[int]
136+
self.subbed_scripthashes = set() # type: set[str]
137+
self.sh_to_funding_txids = collections.defaultdict(set) # type: dict[str, set[str]]
138+
self.sh_to_spending_txids = collections.defaultdict(set) # type: dict[str, set[str]]
123139
self._method_counts = collections.defaultdict(int) # type: dict[str, int]
124140
_active_server_sessions.add(self)
125141

@@ -136,8 +152,11 @@ async def handle_request(self, request):
136152
'blockchain.headers.subscribe': self._handle_headers_subscribe,
137153
'blockchain.block.header': self._handle_block_header,
138154
'blockchain.block.headers': self._handle_block_headers,
155+
'blockchain.scripthash.subscribe': self._handle_scripthash_subscribe,
156+
'blockchain.scripthash.get_history': self._handle_scripthash_get_history,
139157
'blockchain.transaction.get': self._handle_transaction_get,
140158
'blockchain.transaction.broadcast': self._handle_transaction_broadcast,
159+
'blockchain.transaction.get_merkle': self._handle_transaction_get_merkle,
141160
'server.ping': self._handle_ping,
142161
}
143162
handler = handlers.get(request.method)
@@ -162,9 +181,13 @@ async def _handle_server_features(self) -> dict:
162181
async def _handle_estimatefee(self, number, mode=None):
163182
return 1000
164183

165-
async def _handle_headers_subscribe(self):
184+
def _get_headersub_result(self):
166185
return {'hex': BLOCK_HEADERS[self.cur_height].hex(), 'height': self.cur_height}
167186

187+
async def _handle_headers_subscribe(self):
188+
self.subbed_headers = True
189+
return self._get_headersub_result()
190+
168191
async def _handle_block_header(self, height):
169192
return BLOCK_HEADERS[height].hex()
170193

@@ -186,10 +209,97 @@ async def _handle_transaction_get(self, tx_hash: str, verbose=False):
186209
raise RPCError(DAEMON_ERROR, f'daemon error: unknown txid={tx_hash}')
187210
return rawtx.hex()
188211

189-
async def _handle_transaction_broadcast(self, raw_tx: str):
212+
async def _handle_transaction_get_merkle(self, tx_hash: str, height: int) -> dict:
213+
# Fake stuff. Client will ignore it due to config.NETWORK_SKIPMERKLECHECK
214+
return {
215+
"merkle":
216+
[
217+
"713d6c7e6ce7bbea708d61162231eaa8ecb31c4c5dd84f81c20409a90069cb24",
218+
"03dbaec78d4a52fbaf3c7aa5d3fccd9d8654f323940716ddf5ee2e4bda458fde",
219+
"e670224b23f156c27993ac3071940c0ff865b812e21e0a162fe7a005d6e57851",
220+
"369a1619a67c3108a8850118602e3669455c70cdcdb89248b64cc6325575b885",
221+
"4756688678644dcb27d62931f04013254a62aeee5dec139d1aac9f7b1f318112",
222+
"7b97e73abc043836fd890555bfce54757d387943a6860e5450525e8e9ab46be5",
223+
"61505055e8b639b7c64fd58bce6fc5c2378b92e025a02583303f69930091b1c3",
224+
"27a654ff1895385ac14a574a0415d3bbba9ec23a8774f22ec20d53dd0b5386ff",
225+
"5312ed87933075e60a9511857d23d460a085f3b6e9e5e565ad2443d223cfccdc",
226+
"94f60b14a9f106440a197054936e6fb92abbd69d6059b38fdf79b33fc864fca0",
227+
"2d64851151550e8c4d337f335ee28874401d55b358a66f1bafab2c3e9f48773d"
228+
],
229+
"block_height": height,
230+
"pos": 710,
231+
}
232+
233+
async def _handle_transaction_broadcast(self, raw_tx: str) -> str:
190234
tx = Transaction(raw_tx)
191-
self.txs[tx.txid()] = bfh(raw_tx)
192-
return tx.txid()
235+
txid = tx.txid()
236+
self.txs[txid] = bfh(raw_tx)
237+
touched_sh = await self._process_added_tx(txid=txid)
238+
if touched_sh:
239+
await self._send_notifications(touched_sh=touched_sh)
240+
return txid
241+
242+
async def _process_added_tx(self, *, txid: str) -> set[str]:
243+
"""Returns touched scripthashes."""
244+
tx = Transaction(self.txs[txid])
245+
touched_sh = set()
246+
# update sh_to_funding_txids
247+
for txout in tx.outputs():
248+
sh = script_to_scripthash(txout.scriptpubkey)
249+
self.sh_to_funding_txids[sh].add(txid)
250+
touched_sh.add(sh)
251+
# update sh_to_spending_txids
252+
for txin in tx.inputs():
253+
if parent_tx_raw := self.txs.get(txin.prevout.txid.hex()):
254+
parent_tx = Transaction(parent_tx_raw)
255+
ptxout = parent_tx.outputs()[txin.prevout.out_idx]
256+
sh = script_to_scripthash(ptxout.scriptpubkey)
257+
self.sh_to_spending_txids[sh].add(txid)
258+
touched_sh.add(sh)
259+
return touched_sh
260+
261+
async def _handle_scripthash_subscribe(self, sh: str) -> Optional[str]:
262+
self.subbed_scripthashes.add(sh)
263+
hist = self._calc_sh_history(sh)
264+
return history_status(hist)
265+
266+
async def _handle_scripthash_get_history(self, sh: str) -> Sequence[dict]:
267+
hist_tuples = self._calc_sh_history(sh)
268+
hist_dicts = [{"height": height, "tx_hash": txid} for (txid, height) in hist_tuples]
269+
for hist_dict in hist_dicts: # add "fee" key for mempool txs
270+
if hist_dict["height"] in (0, -1,):
271+
hist_dict["fee"] = 0
272+
return hist_dicts
273+
274+
def _calc_sh_history(self, sh: str) -> Sequence[tuple[str, int]]:
275+
txids = self.sh_to_funding_txids[sh] | self.sh_to_spending_txids[sh]
276+
hist = []
277+
for txid in txids:
278+
bh = self.txid_to_block_height[txid]
279+
hist.append((txid, bh))
280+
hist.sort(key=lambda x: x[1]) # FIXME put mempool txs last
281+
return hist
282+
283+
async def _send_notifications(self, *, touched_sh: Iterable[str], height_changed: bool = False) -> None:
284+
if height_changed and self.subbed_headers and self.notified_height != self.cur_height:
285+
self.notified_height = self.cur_height
286+
args = (self._get_headersub_result(),)
287+
await self.send_notification('blockchain.headers.subscribe', args)
288+
touched_sh = set(sh for sh in touched_sh if sh in self.subbed_scripthashes)
289+
for sh in touched_sh:
290+
hist = self._calc_sh_history(sh)
291+
args = (sh, history_status(hist))
292+
await self.send_notification("blockchain.scripthash.subscribe", args)
293+
294+
async def mine_block(self, *, txids_mined: Iterable[str] = None):
295+
if txids_mined is None:
296+
txids_mined = []
297+
self.cur_height += 1
298+
touched_sh = set()
299+
for txid in txids_mined:
300+
self.txid_to_block_height[txid] = self.cur_height
301+
touched_sh |= await self._process_added_tx(txid=txid)
302+
await self._send_notifications(touched_sh=touched_sh, height_changed=True)
193303

194304

195305
class TestInterface(ElectrumTestCase):
@@ -198,6 +308,7 @@ class TestInterface(ElectrumTestCase):
198308
def setUp(self):
199309
super().setUp()
200310
self.config = SimpleConfig({'electrum_path': self.electrum_path})
311+
self.config.NETWORK_SKIPMERKLECHECK = True
201312
self._orig_WAIT_FOR_BUFFER_GROWTH_SECONDS = PaddedRSTransport.WAIT_FOR_BUFFER_GROWTH_SECONDS
202313
PaddedRSTransport.WAIT_FOR_BUFFER_GROWTH_SECONDS = 0
203314

@@ -207,7 +318,7 @@ def tearDown(self):
207318

208319
async def asyncSetUp(self):
209320
await super().asyncSetUp()
210-
self._server: asyncio.base_events.Server = await aiorpcx.serve_rs(ServerSession, "127.0.0.1")
321+
self._server: asyncio.base_events.Server = await aiorpcx.serve_rs(ToyServerSession, "127.0.0.1")
211322
server_socket_addr = self._server.sockets[0].getsockname()
212323
self._server_port = server_socket_addr[1]
213324
self.network = MockNetwork(config=self.config)
@@ -255,3 +366,35 @@ async def test_transaction_broadcast(self):
255366
rawtx2 = await interface.get_transaction(tx.txid())
256367
self.assertEqual(rawtx1, rawtx2)
257368
self.assertEqual(_get_active_server_session()._method_counts["blockchain.transaction.get"], 0)
369+
370+
async def test_dont_request_gethistory_if_status_change_results_from_mempool_txs_simply_getting_mined(self):
371+
"""After a new block is mined, we recv "blockchain.scripthash.subscribe" notifs.
372+
We opportunistically guess the scripthash status changed purely because touching mempool txs just got mined.
373+
If the guess is correct, we won't call the "blockchain.scripthash.get_history" RPC.
374+
"""
375+
interface = await self._start_iface_and_wait_for_sync()
376+
w1 = restore_wallet_from_text__for_unittest("9dk", path=None, config=self.config)['wallet'] # type: Abstract_Wallet
377+
w1.start_network(self.network)
378+
await w1.up_to_date_changed_event.wait()
379+
self.assertEqual(_get_active_server_session()._method_counts["blockchain.scripthash.get_history"], 0)
380+
# fund w1 (in mempool)
381+
funding_tx = "01000000000101e855888b77b1688d08985b863bfe85b354049b4eba923db9b5cf37089975d5d10000000000fdffffff0280969800000000001600140297bde2689a3c79ffe050583b62f86f2d9dae5460abe9000000000016001472df47551b6e7e0c8428814d2e572bc5ac773dda024730440220383efa2f0f5b87f8ce5d6b6eaf48cba03bf522b23fbb23b2ac54ff9d9a8f6a8802206f67d1f909f3c7a22ac0308ac4c19853ffca3a9317e1d7e0c88cc3a86853aaac0121035061949222555a0df490978fe6e7ebbaa96332ecb5c266918fd800c0eef736e7358d1400"
382+
funding_txid = await _get_active_server_session()._handle_transaction_broadcast(funding_tx)
383+
await w1.up_to_date_changed_event.wait()
384+
while not w1.is_up_to_date():
385+
await w1.up_to_date_changed_event.wait()
386+
self.assertEqual(_get_active_server_session()._method_counts["blockchain.scripthash.get_history"], 1)
387+
self.assertEqual(
388+
w1.adb.get_address_history("bcrt1qq2tmmcngng78nllq2pvrkchcdukemtj5jnxz44"),
389+
{funding_txid: 0})
390+
# mine funding tx
391+
await _get_active_server_session().mine_block(txids_mined=[funding_txid])
392+
await w1.up_to_date_changed_event.wait()
393+
while not w1.is_up_to_date():
394+
await w1.up_to_date_changed_event.wait()
395+
# see if we managed to guess new history, and hence did not need to call get_history RPC
396+
self.assertEqual(_get_active_server_session()._method_counts["blockchain.scripthash.get_history"], 1)
397+
self.assertEqual(
398+
w1.adb.get_address_history("bcrt1qq2tmmcngng78nllq2pvrkchcdukemtj5jnxz44"),
399+
{funding_txid: 7})
400+

0 commit comments

Comments
 (0)