Skip to content

Commit cee9445

Browse files
authored
Better return type for get_all_entities_changed (#14604)
Help callers from using the return value incorrectly by ensuring that callers explicitly check if there was a cache hit or not.
1 parent 6a8310f commit cee9445

File tree

8 files changed

+138
-76
lines changed

8 files changed

+138
-76
lines changed

changelog.d/14604.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.

synapse/handlers/appservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,8 @@ async def _get_device_list_summary(
615615
)
616616

617617
# Fetch the users who have modified their device list since then.
618-
users_with_changed_device_lists = (
619-
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
618+
users_with_changed_device_lists = await self.store.get_all_devices_changed(
619+
from_key, to_key=new_key
620620
)
621621

622622
# Filter out any users the application service is not interested in

synapse/handlers/presence.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,10 +1692,12 @@ async def get_new_events(
16921692

16931693
if from_key is not None:
16941694
# First get all users that have had a presence update
1695-
updated_users = stream_change_cache.get_all_entities_changed(from_key)
1695+
result = stream_change_cache.get_all_entities_changed(from_key)
16961696

16971697
# Cross-reference users we're interested in with those that have had updates.
1698-
if updated_users is not None:
1698+
if result.hit:
1699+
updated_users = result.entities
1700+
16991701
# If we have the full list of changes for presence we can
17001702
# simply check which ones share a room with the user.
17011703
get_updates_counter.labels("stream").inc()
@@ -1767,9 +1769,9 @@ async def _filter_all_presence_updates_for_user(
17671769
updated_users = None
17681770
if from_key:
17691771
# Only return updates since the last sync
1770-
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
1771-
from_key
1772-
)
1772+
result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
1773+
if result.hit:
1774+
updated_users = result.entities
17731775

17741776
if updated_users is not None:
17751777
# Get the actual presence update for each change

synapse/handlers/sync.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,10 +1528,12 @@ async def _generate_sync_entry_for_device_list(
15281528
#
15291529
# If we don't have that info cached then we get all the users that
15301530
# share a room with our user and check if those users have changed.
1531-
changed_users = self.store.get_cached_device_list_changes(
1531+
cache_result = self.store.get_cached_device_list_changes(
15321532
since_token.device_list_key
15331533
)
1534-
if changed_users is not None:
1534+
if cache_result.hit:
1535+
changed_users = cache_result.entities
1536+
15351537
result = await self.store.get_rooms_for_users(changed_users)
15361538

15371539
for changed_user_id, entries in result.items():

synapse/handlers/typing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,11 +420,11 @@ async def get_all_typing_updates(
420420
if last_id == current_id:
421421
return [], current_id, False
422422

423-
changed_rooms: Optional[
424-
Iterable[str]
425-
] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
423+
result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
426424

427-
if changed_rooms is None:
425+
if result.hit:
426+
changed_rooms: Iterable[str] = result.entities
427+
else:
428428
changed_rooms = self._room_serials
429429

430430
rows = []

synapse/storage/databases/main/devices.py

Lines changed: 71 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@
5858
from synapse.util import json_decoder, json_encoder
5959
from synapse.util.caches.descriptors import cached, cachedList
6060
from synapse.util.caches.lrucache import LruCache
61-
from synapse.util.caches.stream_change_cache import StreamChangeCache
61+
from synapse.util.caches.stream_change_cache import (
62+
AllEntitiesChangedResult,
63+
StreamChangeCache,
64+
)
6265
from synapse.util.cancellation import cancellable
6366
from synapse.util.iterutils import batch_iter
6467
from synapse.util.stringutils import shortstr
@@ -799,18 +802,66 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
799802
def get_cached_device_list_changes(
800803
self,
801804
from_key: int,
802-
) -> Optional[List[str]]:
805+
) -> AllEntitiesChangedResult:
803806
"""Get set of users whose devices have changed since `from_key`, or None
804807
if that information is not in our cache.
805808
"""
806809

807810
return self._device_list_stream_cache.get_all_entities_changed(from_key)
808811

812+
@cancellable
813+
async def get_all_devices_changed(
814+
self,
815+
from_key: int,
816+
to_key: int,
817+
) -> Set[str]:
818+
"""Get all users whose devices have changed in the given range.
819+
820+
Args:
821+
from_key: The minimum device lists stream token to query device list
822+
changes for, exclusive.
823+
to_key: The maximum device lists stream token to query device list
824+
changes for, inclusive.
825+
826+
Returns:
827+
The set of user_ids whose devices have changed since `from_key`
828+
(exclusive) until `to_key` (inclusive).
829+
"""
830+
831+
result = self._device_list_stream_cache.get_all_entities_changed(from_key)
832+
833+
if result.hit:
834+
# We know which users might have changed devices.
835+
if not result.entities:
836+
# If no users then we can return early.
837+
return set()
838+
839+
# Otherwise we need to filter down the list
840+
return await self.get_users_whose_devices_changed(
841+
from_key, result.entities, to_key
842+
)
843+
844+
# If the cache didn't tell us anything, we just need to query the full
845+
# range.
846+
sql = """
847+
SELECT DISTINCT user_id FROM device_lists_stream
848+
WHERE ? < stream_id AND stream_id <= ?
849+
"""
850+
851+
rows = await self.db_pool.execute(
852+
"get_all_devices_changed",
853+
None,
854+
sql,
855+
from_key,
856+
to_key,
857+
)
858+
return {u for u, in rows}
859+
809860
@cancellable
810861
async def get_users_whose_devices_changed(
811862
self,
812863
from_key: int,
813-
user_ids: Optional[Collection[str]] = None,
864+
user_ids: Collection[str],
814865
to_key: Optional[int] = None,
815866
) -> Set[str]:
816867
"""Get set of users whose devices have changed since `from_key` that
@@ -830,52 +881,32 @@ async def get_users_whose_devices_changed(
830881
"""
831882
# Get set of users who *may* have changed. Users not in the returned
832883
# list have definitely not changed.
833-
user_ids_to_check: Optional[Collection[str]]
834-
if user_ids is None:
835-
# Get set of all users that have had device list changes since 'from_key'
836-
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
837-
from_key
838-
)
839-
else:
840-
# The same as above, but filter results to only those users in 'user_ids'
841-
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
842-
user_ids, from_key
843-
)
884+
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
885+
user_ids, from_key
886+
)
844887

845888
# If an empty set was returned, there's nothing to do.
846-
if user_ids_to_check is not None and not user_ids_to_check:
889+
if not user_ids_to_check:
847890
return set()
848891

849-
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
850-
stream_id_where_clause = "stream_id > ?"
851-
sql_args = [from_key]
852-
853-
if to_key:
854-
stream_id_where_clause += " AND stream_id <= ?"
855-
sql_args.append(to_key)
892+
if to_key is None:
893+
to_key = self._device_list_id_gen.get_current_token()
856894

857-
sql = f"""
895+
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
896+
sql = """
858897
SELECT DISTINCT user_id FROM device_lists_stream
859-
WHERE {stream_id_where_clause}
898+
WHERE ? < stream_id AND stream_id <= ? AND %s
860899
"""
861900

862-
# If the stream change cache gave us no information, fetch *all*
863-
# users between the stream IDs.
864-
if user_ids_to_check is None:
865-
txn.execute(sql, sql_args)
866-
return {user_id for user_id, in txn}
901+
changes: Set[str] = set()
867902

868-
# Otherwise, fetch changes for the given users.
869-
else:
870-
changes: Set[str] = set()
871-
872-
# Query device changes with a batch of users at a time
873-
for chunk in batch_iter(user_ids_to_check, 100):
874-
clause, args = make_in_list_sql_clause(
875-
txn.database_engine, "user_id", chunk
876-
)
877-
txn.execute(sql + " AND " + clause, sql_args + args)
878-
changes.update(user_id for user_id, in txn)
903+
# Query device changes with a batch of users at a time
904+
for chunk in batch_iter(user_ids_to_check, 100):
905+
clause, args = make_in_list_sql_clause(
906+
txn.database_engine, "user_id", chunk
907+
)
908+
txn.execute(sql % (clause,), [from_key, to_key] + args)
909+
changes.update(user_id for user_id, in txn)
879910

880911
return changes
881912

synapse/util/caches/stream_change_cache.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import math
1717
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
1818

19+
import attr
1920
from sortedcontainers import SortedDict
2021

2122
from synapse.util import caches
@@ -26,6 +27,29 @@
2627
EntityType = str
2728

2829

30+
@attr.s(auto_attribs=True, frozen=True, slots=True)
31+
class AllEntitiesChangedResult:
32+
"""Return type of `get_all_entities_changed`.
33+
34+
Callers must check that there was a cache hit, via `result.hit`, before
35+
using the entities in `result.entities`.
36+
37+
This specifically does *not* implement helpers such as `__bool__` to ensure
38+
that callers do the correct checks.
39+
"""
40+
41+
_entities: Optional[List[EntityType]]
42+
43+
@property
44+
def hit(self) -> bool:
45+
return self._entities is not None
46+
47+
@property
48+
def entities(self) -> List[EntityType]:
49+
assert self._entities is not None
50+
return self._entities
51+
52+
2953
class StreamChangeCache:
3054
"""
3155
Keeps track of the stream positions of the latest change in a set of entities.
@@ -153,19 +177,19 @@ def get_entities_changed(
153177
This will be all entities if the given stream position is at or earlier
154178
than the earliest known stream position.
155179
"""
156-
changed_entities = self.get_all_entities_changed(stream_pos)
157-
if changed_entities is not None:
180+
cache_result = self.get_all_entities_changed(stream_pos)
181+
if cache_result.hit:
158182
# We now do an intersection, trying to do so in the most efficient
159183
# way possible (some of these sets are *large*). First check in the
160184
# given iterable is already a set that we can reuse, otherwise we
161185
# create a set of the *smallest* of the two iterables and call
162186
# `intersection(..)` on it (this can be twice as fast as the reverse).
163187
if isinstance(entities, (set, frozenset)):
164-
result = entities.intersection(changed_entities)
165-
elif len(changed_entities) < len(entities):
166-
result = set(changed_entities).intersection(entities)
188+
result = entities.intersection(cache_result.entities)
189+
elif len(cache_result.entities) < len(entities):
190+
result = set(cache_result.entities).intersection(entities)
167191
else:
168-
result = set(entities).intersection(changed_entities)
192+
result = set(entities).intersection(cache_result.entities)
169193
self.metrics.inc_hits()
170194
else:
171195
result = set(entities)
@@ -202,36 +226,34 @@ def has_any_entity_changed(self, stream_pos: int) -> bool:
202226
self.metrics.inc_hits()
203227
return stream_pos < self._cache.peekitem()[0]
204228

205-
def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
229+
def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
206230
"""
207231
Returns all entities that have had changes after the given position.
208232
209-
If the stream change cache does not go far enough back, i.e. the position
210-
is too old, it will return None.
233+
If the stream change cache does not go far enough back, i.e. the
234+
position is too old, it will return None.
211235
212236
Returns the entities in the order that they were changed.
213237
214238
Args:
215239
stream_pos: The stream position to check for changes after.
216240
217241
Return:
218-
Entities which have changed after the given stream position.
219-
220-
None if the given stream position is at or earlier than the earliest
221-
known stream position.
242+
A class indicating if we have the requested data cached, and if so
243+
includes the entities in the order they were changed.
222244
"""
223245
assert isinstance(stream_pos, int)
224246

225247
# _cache is not valid at or before the earliest known stream position, so
226248
# return None to mark that it is unknown if an entity has changed.
227249
if stream_pos <= self._earliest_known_stream_pos:
228-
return None
250+
return AllEntitiesChangedResult(None)
229251

230252
changed_entities: List[EntityType] = []
231253

232254
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
233255
changed_entities.extend(self._cache[k])
234-
return changed_entities
256+
return AllEntitiesChangedResult(changed_entities)
235257

236258
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
237259
"""

tests/util/test_stream_change_cache.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,21 @@ def test_entity_has_changed_pops_off_start(self) -> None:
7373
# The oldest item has been popped off
7474
self.assertTrue("[email protected]" not in cache._entity_to_key)
7575

76-
self.assertEqual(cache.get_all_entities_changed(3), ["[email protected]"])
77-
self.assertIsNone(cache.get_all_entities_changed(2))
76+
self.assertEqual(
77+
cache.get_all_entities_changed(3).entities, ["[email protected]"]
78+
)
79+
self.assertFalse(cache.get_all_entities_changed(2).hit)
7880

7981
# If we update an existing entity, it keeps the two existing entities
8082
cache.entity_has_changed("[email protected]", 5)
8183
self.assertEqual(
8284
{"[email protected]", "[email protected]"}, set(cache._entity_to_key)
8385
)
8486
self.assertEqual(
85-
cache.get_all_entities_changed(3),
87+
cache.get_all_entities_changed(3).entities,
8688
8789
)
88-
self.assertIsNone(cache.get_all_entities_changed(2))
90+
self.assertFalse(cache.get_all_entities_changed(2).hit)
8991

9092
def test_get_all_entities_changed(self) -> None:
9193
"""
@@ -105,10 +107,12 @@ def test_get_all_entities_changed(self) -> None:
105107
# Results are ordered so either of these are valid.
106108
107109
108-
self.assertTrue(r == ok1 or r == ok2)
110+
self.assertTrue(r.entities == ok1 or r.entities == ok2)
109111

110-
self.assertEqual(cache.get_all_entities_changed(3), ["[email protected]"])
111-
self.assertEqual(cache.get_all_entities_changed(1), None)
112+
self.assertEqual(
113+
cache.get_all_entities_changed(3).entities, ["[email protected]"]
114+
)
115+
self.assertFalse(cache.get_all_entities_changed(1).hit)
112116

113117
# ... later, things gest more updates
114118
cache.entity_has_changed("[email protected]", 5)
@@ -128,7 +132,7 @@ def test_get_all_entities_changed(self) -> None:
128132
129133
]
130134
r = cache.get_all_entities_changed(3)
131-
self.assertTrue(r == ok1 or r == ok2)
135+
self.assertTrue(r.entities == ok1 or r.entities == ok2)
132136

133137
def test_has_any_entity_changed(self) -> None:
134138
"""

0 commit comments

Comments
 (0)