5858from synapse .util import json_decoder , json_encoder
5959from synapse .util .caches .descriptors import cached , cachedList
6060from synapse .util .caches .lrucache import LruCache
61- from synapse .util .caches .stream_change_cache import StreamChangeCache
61+ from synapse .util .caches .stream_change_cache import (
62+ AllEntitiesChangedResult ,
63+ StreamChangeCache ,
64+ )
6265from synapse .util .cancellation import cancellable
6366from synapse .util .iterutils import batch_iter
6467from synapse .util .stringutils import shortstr
@@ -799,18 +802,66 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
799802 def get_cached_device_list_changes (
800803 self ,
801804 from_key : int ,
802- ) -> Optional [ List [ str ]] :
805+ ) -> AllEntitiesChangedResult :
803806 """Get set of users whose devices have changed since `from_key`, or None
804807 if that information is not in our cache.
805808 """
806809
807810 return self ._device_list_stream_cache .get_all_entities_changed (from_key )
808811
812+ @cancellable
813+ async def get_all_devices_changed (
814+ self ,
815+ from_key : int ,
816+ to_key : int ,
817+ ) -> Set [str ]:
818+ """Get all users whose devices have changed in the given range.
819+
820+ Args:
821+ from_key: The minimum device lists stream token to query device list
822+ changes for, exclusive.
823+ to_key: The maximum device lists stream token to query device list
824+ changes for, inclusive.
825+
826+ Returns:
827+ The set of user_ids whose devices have changed since `from_key`
828+ (exclusive) until `to_key` (inclusive).
829+ """
830+
831+ result = self ._device_list_stream_cache .get_all_entities_changed (from_key )
832+
833+ if result .hit :
834+ # We know which users might have changed devices.
835+ if not result .entities :
836+ # If no users then we can return early.
837+ return set ()
838+
839+ # Otherwise we need to filter down the list
840+ return await self .get_users_whose_devices_changed (
841+ from_key , result .entities , to_key
842+ )
843+
844+ # If the cache didn't tell us anything, we just need to query the full
845+ # range.
846+ sql = """
847+ SELECT DISTINCT user_id FROM device_lists_stream
848+ WHERE ? < stream_id AND stream_id <= ?
849+ """
850+
851+ rows = await self .db_pool .execute (
852+ "get_all_devices_changed" ,
853+ None ,
854+ sql ,
855+ from_key ,
856+ to_key ,
857+ )
858+ return {u for u , in rows }
859+
809860 @cancellable
810861 async def get_users_whose_devices_changed (
811862 self ,
812863 from_key : int ,
813- user_ids : Optional [ Collection [str ]] = None ,
864+ user_ids : Collection [str ],
814865 to_key : Optional [int ] = None ,
815866 ) -> Set [str ]:
816867 """Get set of users whose devices have changed since `from_key` that
@@ -830,52 +881,32 @@ async def get_users_whose_devices_changed(
830881 """
831882 # Get set of users who *may* have changed. Users not in the returned
832883 # list have definitely not changed.
833- user_ids_to_check : Optional [Collection [str ]]
834- if user_ids is None :
835- # Get set of all users that have had device list changes since 'from_key'
836- user_ids_to_check = self ._device_list_stream_cache .get_all_entities_changed (
837- from_key
838- )
839- else :
840- # The same as above, but filter results to only those users in 'user_ids'
841- user_ids_to_check = self ._device_list_stream_cache .get_entities_changed (
842- user_ids , from_key
843- )
884+ user_ids_to_check = self ._device_list_stream_cache .get_entities_changed (
885+ user_ids , from_key
886+ )
844887
845888 # If an empty set was returned, there's nothing to do.
846- if user_ids_to_check is not None and not user_ids_to_check :
889+ if not user_ids_to_check :
847890 return set ()
848891
849- def _get_users_whose_devices_changed_txn (txn : LoggingTransaction ) -> Set [str ]:
850- stream_id_where_clause = "stream_id > ?"
851- sql_args = [from_key ]
852-
853- if to_key :
854- stream_id_where_clause += " AND stream_id <= ?"
855- sql_args .append (to_key )
892+ if to_key is None :
893+ to_key = self ._device_list_id_gen .get_current_token ()
856894
857- sql = f"""
895+ def _get_users_whose_devices_changed_txn (txn : LoggingTransaction ) -> Set [str ]:
896+ sql = """
858897 SELECT DISTINCT user_id FROM device_lists_stream
859- WHERE { stream_id_where_clause }
898+ WHERE ? < stream_id AND stream_id <= ? AND %s
860899 """
861900
862- # If the stream change cache gave us no information, fetch *all*
863- # users between the stream IDs.
864- if user_ids_to_check is None :
865- txn .execute (sql , sql_args )
866- return {user_id for user_id , in txn }
901+ changes : Set [str ] = set ()
867902
868- # Otherwise, fetch changes for the given users.
869- else :
870- changes : Set [str ] = set ()
871-
872- # Query device changes with a batch of users at a time
873- for chunk in batch_iter (user_ids_to_check , 100 ):
874- clause , args = make_in_list_sql_clause (
875- txn .database_engine , "user_id" , chunk
876- )
877- txn .execute (sql + " AND " + clause , sql_args + args )
878- changes .update (user_id for user_id , in txn )
903+ # Query device changes with a batch of users at a time
904+ for chunk in batch_iter (user_ids_to_check , 100 ):
905+ clause , args = make_in_list_sql_clause (
906+ txn .database_engine , "user_id" , chunk
907+ )
908+ txn .execute (sql % (clause ,), [from_key , to_key ] + args )
909+ changes .update (user_id for user_id , in txn )
879910
880911 return changes
881912
0 commit comments