Skip to content

Commit 2566ab9

Browse files
committed
Serialize calls to _get_room_for_address to prevent races
Also, remove _cachegetter and replace it with explicit caches
1 parent ac28cda commit 2566ab9

File tree

1 file changed

+35
-64
lines changed

1 file changed

+35
-64
lines changed

raiden/network/transport/matrix.py

Lines changed: 35 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
from operator import attrgetter, itemgetter
77
from random import Random
88
from urllib.parse import urlparse
9-
from weakref import WeakKeyDictionary, WeakValueDictionary
109

1110
import gevent
1211
import structlog
13-
from cachetools import TTLCache, cachedmethod
12+
from cachetools import LRUCache, TTLCache, cached, cachedmethod
1413
from eth_utils import (
1514
decode_hex,
1615
encode_hex,
@@ -19,6 +18,7 @@
1918
to_checksum_address,
2019
to_normalized_address,
2120
)
21+
from gevent.lock import Semaphore
2222
from matrix_client.errors import MatrixError, MatrixRequestError
2323
from matrix_client.user import User
2424

@@ -63,13 +63,10 @@
6363
Dict,
6464
Iterable,
6565
List,
66-
Mapping,
6766
NewType,
6867
Optional,
6968
Set,
7069
Tuple,
71-
Type,
72-
TypeVar,
7370
Union,
7471
)
7572
from raiden_libs.exceptions import InvalidSignature
@@ -78,40 +75,9 @@
7875

7976
log = structlog.get_logger(__name__)
8077

81-
_CT = TypeVar('CT') # class type
82-
_CIT = Union[_CT, Type[_CT]] # class or instance type
83-
_RT = TypeVar('RT') # return type
84-
_CacheT = Mapping[Tuple, _RT] # cache type (mapping)
8578
_RoomID = NewType('RoomID', str)
8679

8780

88-
def _cachegetter(
89-
attr: str,
90-
cachefactory: Callable[[], _CacheT] = WeakKeyDictionary, # WeakKeyDict best for properties
91-
) -> Callable[[_CIT], _CacheT]:
92-
"""Returns a safer attrgetter which constructs the missing object with cachefactory
93-
94-
May be used for normal methods, classmethods and properties, as default
95-
factory is a WeakKeyDictionary (good for storing weak-refs for self or cls).
96-
It may also safely be used with staticmethods, if first parameter is an object
97-
on which the cache will be stored.
98-
Better when used with key getter. If it's a tuple, you should use e.g. cachefactory=dict
99-
Example usage with cachetools.cachedmethod:
100-
class Foo:
101-
@property
102-
@cachedmethod(_cachegetter("__bar_cache"))
103-
def bar(self) -> _RT:
104-
return 2+3
105-
"""
106-
def cachegetter(cls_or_obj: _CIT) -> _CacheT:
107-
cache = getattr(cls_or_obj, attr, None)
108-
if cache is None:
109-
cache = cachefactory()
110-
setattr(cls_or_obj, attr, cache)
111-
return cache
112-
return cachegetter
113-
114-
11581
class UserPresence(Enum):
11682
ONLINE = 'online'
11783
UNAVAILABLE = 'unavailable'
@@ -126,6 +92,8 @@ class MatrixTransport(Runnable):
12692
_room_prefix = 'raiden'
12793
_room_sep = '_'
12894
_userid_re = re.compile(r'^@(0x[0-9a-f]{40})(?:\.[0-9a-f]{8})?(?::.+)?$')
95+
_addresses_cache = LRUCache(512) # deterministic thus shared
96+
log = log
12997

13098
def __init__(self, config: dict):
13199
super().__init__()
@@ -180,11 +148,14 @@ def _http_retry_delay() -> Iterable[float]:
180148

181149
self._stop_event = gevent.event.Event()
182150
self._stop_event.set()
183-
self._health_semaphore = gevent.lock.Semaphore()
184151

185152
self._client.add_invite_listener(self._handle_invite)
186153
self._client.add_presence_listener(self._handle_presence_change)
187154

155+
self._messages_cache = TTLCache(32, 4)
156+
self._health_lock = Semaphore()
157+
self._getroom_lock = Semaphore()
158+
188159
def start(
189160
self,
190161
raiden_service: RaidenService,
@@ -195,6 +166,8 @@ def start(
195166
self._raiden_service = raiden_service
196167

197168
self._login_or_register()
169+
self.log = log.bind(current_user=self._user_id, node=pex(self._raiden_service.address))
170+
198171
if self._client._handle_thread:
199172
# wait on _handle_thread for initial sync
200173
# this is needed so the rooms are populated before we _inventory_rooms
@@ -242,6 +215,7 @@ def stop(self):
242215
# wait own greenlets, no need to get on them, exceptions should be raised in _run()
243216
gevent.wait(self.greenlets)
244217
self._client.logout()
218+
del self.log
245219
# parent may want to call get() after stop(), to ensure _run errors are re-raised
246220
# we don't call it here to avoid deadlock when self crashes and calls stop() on finally
247221

@@ -262,7 +236,7 @@ def start_health_check(self, node_address):
262236
if self._stop_event.ready():
263237
return
264238

265-
with self._health_semaphore:
239+
with self._health_lock:
266240
if node_address in self._address_to_userids:
267241
return # already healthchecked
268242

@@ -337,13 +311,6 @@ def _queueids_to_queues(self) -> QueueIdsToQueues:
337311
def _user_id(self) -> Optional[str]:
338312
return getattr(self, '_client', None) and getattr(self._client, 'user_id', None)
339313

340-
@property
341-
@cachedmethod(_cachegetter('__log_cache', dict), key=attrgetter('_user_id'))
342-
def log(self):
343-
if not self._user_id:
344-
return log
345-
return log.bind(current_user=self._user_id, node=pex(self._raiden_service.address))
346-
347314
@property
348315
def _network_name(self) -> str:
349316
return ID_TO_NETWORKNAME.get(
@@ -556,7 +523,7 @@ def _handle_invite(self, room_id: _RoomID, state: dict):
556523
)
557524

558525
@cachedmethod(
559-
_cachegetter('__messages_cache', lambda: TTLCache(32, 4)),
526+
attrgetter('_messages_cache'),
560527
key=lambda _, room, event: (room.room_id, event['type'], event['content'].get('body')),
561528
)
562529
def _handle_message(self, room, event) -> bool:
@@ -828,8 +795,9 @@ def _send_immediate(
828795

829796
self._send_raw(receiver_address, data)
830797

831-
def _send_raw(self, receiver_address, data):
832-
room = self._get_room_for_address(receiver_address)
798+
def _send_raw(self, receiver_address: Address, data: str):
799+
with self._getroom_lock:
800+
room = self._get_room_for_address(receiver_address)
833801
if not room:
834802
return
835803
self.log.debug('SEND', receiver=pex(receiver_address), room=room, data=data)
@@ -851,9 +819,6 @@ def _get_room_for_address(
851819
if room_ids: # if we know any room for this user, use the first one
852820
return self._client.rooms[room_ids[0]]
853821

854-
# The addresses are being sorted to ensure the same channel is used for both directions
855-
# of communication.
856-
# e.g.: raiden_ropsten_0xaaaa_0xbbbb
857822
address_pair = sorted([
858823
to_normalized_address(address)
859824
for address in [address, self._raiden_service.address]
@@ -1126,19 +1091,20 @@ def _recover(data: bytes, signature: bytes) -> Address:
11261091
signature=signature,
11271092
))
11281093

1129-
@staticmethod
1130-
@cachedmethod(_cachegetter('__address_cache', dict), key=attrgetter('user_id', 'displayname'))
1131-
def _validate_userid_signature(user: User) -> Optional[Address]:
1094+
@cached(_addresses_cache, key=lambda _, user: (user.user_id, user.displayname))
1095+
def _validate_userid_signature(self, user: User) -> Optional[Address]:
11321096
""" Validate a userId format and signature on displayName, and return its address"""
11331097
# display_name should be an address in the self._userid_re format
1134-
match = MatrixTransport._userid_re.match(user.user_id)
1098+
match = self._userid_re.match(user.user_id)
11351099
if not match:
11361100
return None
1137-
encoded_address: str = match.group(1)
1101+
1102+
encoded_address: AddressHex = match.group(1)
11381103
address: Address = to_canonical_address(encoded_address)
1104+
11391105
try:
11401106
displayname = user.get_display_name()
1141-
recovered = MatrixTransport._recover(
1107+
recovered = self._recover(
11421108
user.user_id.encode(),
11431109
decode_hex(displayname),
11441110
)
@@ -1154,14 +1120,19 @@ def _validate_userid_signature(user: User) -> Optional[Address]:
11541120
return None
11551121
return address
11561122

1157-
@cachedmethod(
1158-
_cachegetter('__users_cache', WeakValueDictionary),
1159-
key=lambda _, user: user.user_id if isinstance(user, User) else user,
1160-
)
11611123
def _get_user(self, user: Union[User, str]) -> User:
1162-
""" Creates an User from an user_id, if none, or fetch a cached User """
1163-
if not isinstance(user, User):
1164-
user = self._client.get_user(user)
1124+
"""Creates an User from an user_id, if none, or fetch a cached User
1125+
1126+
As all users are supposed to be in discovery room, its members dict is used for caching"""
1127+
user_id: str = getattr(user, 'user_id', user)
1128+
if self._discovery_room and user_id in self._discovery_room._members:
1129+
duser = self._discovery_room._members[user_id]
1130+
# if handed a User instance with displayname set, update the discovery room cache
1131+
if getattr(user, 'displayname', None):
1132+
duser.displayname = user.displayname
1133+
user = duser
1134+
elif not isinstance(user, User):
1135+
user = self._client.get_user(user_id)
11651136
return user
11661137

11671138
def _set_room_id_for_address(self, address: Address, room_id: Optional[_RoomID]):

0 commit comments

Comments
 (0)