1313
1414from examples .tutorial .reactivex .shared import (Message , chat_filename_mimetype , ClientStatistics ,
1515 ServerStatisticsRequest , ServerStatistics , dataclass_to_payload ,
16- decode_dataclass )
16+ decode_dataclass , decode_payload )
1717from rsocket .extensions .composite_metadata import CompositeMetadata
1818from rsocket .extensions .helpers import composite , metadata_item
1919from rsocket .frame_helpers import ensure_bytes
@@ -36,15 +36,15 @@ class SessionId(str): # allow weak reference
3636@dataclass ()
3737class UserSessionData :
3838 username : str
39- session_id : str
39+ session_id : SessionId
4040 messages : Queue = field (default_factory = Queue )
4141 statistics : Optional [ClientStatistics ] = None
4242 requested_statistics : ServerStatisticsRequest = field (default_factory = ServerStatisticsRequest )
4343
4444
4545@dataclass (frozen = True )
4646class ChatData :
47- channel_users : Dict [str , Set [str ]] = field (default_factory = lambda : defaultdict (WeakSet ))
47+ channel_users : Dict [str , Set [SessionId ]] = field (default_factory = lambda : defaultdict (WeakSet ))
4848 files : Dict [str , bytes ] = field (default_factory = dict )
4949 channel_messages : Dict [str , Queue ] = field (default_factory = lambda : defaultdict (Queue ))
5050 user_session_by_id : Dict [str , UserSessionData ] = field (default_factory = WeakValueDictionary )
@@ -95,6 +95,13 @@ def new_statistics_data(requested_statistics: ServerStatisticsRequest):
9595 return ServerStatistics (** statistics_data )
9696
9797
98+ def find_username_by_session (session_id : SessionId ) -> Optional [str ]:
99+ session = chat_data .user_session_by_id .get (session_id )
100+ if session is None :
101+ return None
102+ return session .username
103+
104+
98105class ChatUserSession :
99106
100107 def __init__ (self ):
@@ -105,11 +112,10 @@ def remove(self):
105112 del chat_data .user_session_by_id [self ._session .session_id ]
106113
107114 def router_factory (self ):
108- router = RequestRouter ()
115+ router = RequestRouter (payload_mapper = decode_payload )
109116
110117 @router .response ('login' )
111- async def login (payload : Payload ) -> Observable :
112- username = utf8_decode (payload .data )
118+ async def login (username : str ) -> Observable :
113119 logging .info (f'New user: { username } ' )
114120 session_id = SessionId (uuid .uuid4 ())
115121 self ._session = UserSessionData (username , session_id )
@@ -118,15 +124,13 @@ async def login(payload: Payload) -> Observable:
118124 return reactivex .just (Payload (ensure_bytes (session_id )))
119125
120126 @router .response ('channel.join' )
121- async def join_channel (payload : Payload ) -> Observable :
122- channel_name = utf8_decode (payload .data )
127+ async def join_channel (channel_name : str ) -> Observable :
123128 ensure_channel_exists (channel_name )
124129 chat_data .channel_users [channel_name ].add (self ._session .session_id )
125130 return reactivex .empty ()
126131
127132 @router .response ('channel.leave' )
128- async def leave_channel (payload : Payload ) -> Observable :
129- channel_name = utf8_decode (payload .data )
133+ async def leave_channel (channel_name : str ) -> Observable :
130134 chat_data .channel_users [channel_name ].discard (self ._session .session_id )
131135 return reactivex .empty ()
132136
@@ -157,10 +161,17 @@ async def get_channels() -> Observable:
157161 return reactivex .from_iterable (
158162 (Payload (ensure_bytes (channel )) for channel in chat_data .channel_messages .keys ()))
159163
160- @router .fire_and_forget ('statistics' )
161- async def receive_statistics (payload : Payload ):
162- statistics = decode_dataclass (payload .data , ClientStatistics )
164+ @router .stream ('channel.users' )
165+ async def get_channel_users (channel_name : str ) -> Observable :
166+ if channel_name not in chat_data .channel_users :
167+ return reactivex .empty ()
163168
169+ return reactivex .from_iterable (Payload (ensure_bytes (find_username_by_session (session_id ))) for
170+ session_id in
171+ chat_data .channel_users [channel_name ])
172+
173+ @router .fire_and_forget ('statistics' )
174+ async def receive_statistics (statistics : ClientStatistics ):
164175 logging .info ('Received client statistics. memory usage: %s' , statistics .memory_usage )
165176
166177 self ._session .statistics = statistics
@@ -198,9 +209,7 @@ def on_next(payload: Payload):
198209 limit_rate = 2 )
199210
200211 @router .response ('message' )
201- async def send_message (payload : Payload ) -> Observable :
202- message = decode_dataclass (payload .data , Message )
203-
212+ async def send_message (message : Message ) -> Observable :
204213 logging .info ('Received message for user: %s, channel: %s' , message .user , message .channel )
205214
206215 target_message = Message (self ._session .username , message .content , message .channel )
0 commit comments