@@ -115,6 +115,123 @@ async def handle_server_messages() -> None:
115115 await server_handler
116116
117117
118+ async def test_subscription_cancel (ws_server : WsServerFixture ) -> None :
119+ (urimeta , recv , conn ) = ws_server
120+
121+ client = Client (
122+ client_id = "CLIENT1" ,
123+ server_id = "SERVER" ,
124+ transport_options = TransportOptions (),
125+ uri_and_metadata_factory = urimeta ,
126+ )
127+
128+ connecting = asyncio .create_task (client .ensure_connected ())
129+ request_msg = parse_transport_msg (await recv .get ())
130+
131+ assert not isinstance (request_msg , str )
132+ assert (serverconn := conn ())
133+ handshake_request : ControlMessageHandshakeRequest [None ] = (
134+ ControlMessageHandshakeRequest (** request_msg .payload )
135+ )
136+
137+ handshake_resp = ControlMessageHandshakeResponse (
138+ status = HandShakeStatus (
139+ ok = True ,
140+ ),
141+ )
142+ handshake_request .sessionId
143+
144+ msg = TransportMessage (
145+ from_ = request_msg .from_ ,
146+ to = request_msg .to ,
147+ streamId = request_msg .streamId ,
148+ controlFlags = 0 ,
149+ id = nanoid .generate (),
150+ seq = 0 ,
151+ ack = 0 ,
152+ payload = handshake_resp .model_dump (),
153+ )
154+ packed = msgpack .packb (
155+ msg .model_dump (by_alias = True , exclude_none = True ), datetime = True
156+ )
157+ await serverconn .send (packed )
158+
159+ received_waiter = asyncio .Event ()
160+
161+ async def handle_server_messages () -> None :
162+ request_msg = parse_transport_msg (await recv .get ())
163+ assert not isinstance (request_msg , str )
164+
165+ logging .debug ("request_msg: %r" , repr (request_msg ))
166+ seq = 0
167+
168+ while True :
169+ try :
170+ cancel_msg = parse_transport_msg (recv .get_nowait ())
171+ break
172+ except asyncio .queues .QueueEmpty :
173+ pass
174+
175+ msg = TransportMessage (
176+ from_ = request_msg .from_ ,
177+ to = request_msg .to ,
178+ streamId = request_msg .streamId ,
179+ controlFlags = 0 ,
180+ id = nanoid .generate (),
181+ seq = seq ,
182+ ack = 0 ,
183+ payload = {
184+ "ok" : True ,
185+ "payload" : {
186+ "hello" : "world" ,
187+ },
188+ },
189+ )
190+ seq += 1
191+ packed = msgpack .packb (
192+ msg .model_dump (by_alias = True , exclude_none = True ), datetime = True
193+ )
194+ await serverconn .send (packed )
195+ await asyncio .sleep (0.1 )
196+
197+ if seq > 5 :
198+ received_waiter .set ()
199+
200+ assert not isinstance (cancel_msg , str )
201+ assert cancel_msg .controlFlags == STREAM_CANCEL_BIT
202+
203+ server_handler = asyncio .create_task (handle_server_messages ())
204+
205+ async def receive_chunks () -> None :
206+ async for chunk in client .send_subscription (
207+ "test" ,
208+ "bigstream" ,
209+ {},
210+ lambda x : x ,
211+ lambda x : x ,
212+ lambda x : x ,
213+ ):
214+ print (repr (chunk ))
215+
216+ receive_task = asyncio .create_task (receive_chunks ())
217+
218+ # Wait until we've seen at least a few messages from the upload Task
219+ await received_waiter .wait ()
220+
221+ receive_task .cancel ()
222+ try :
223+ await receive_task
224+ except asyncio .CancelledError :
225+ pass
226+
227+ await client .close ()
228+ await connecting
229+
230+ # Ensure we're listening to close messages as well
231+ server_handler .cancel ()
232+ await server_handler
233+
234+
118235async def test_upload_cancel (ws_server : WsServerFixture ) -> None :
119236 (urimeta , recv , conn ) = ws_server
120237
0 commit comments