44import datetime
55import gzip
66import typing
7- from collections import deque
7+ from collections import OrderedDict
88from dataclasses import dataclass
99from unittest import mock
1010
@@ -52,9 +52,9 @@ def default_executor():
5252 executor .shutdown ()
5353
5454
55- def stub_partition_session ():
55+ def stub_partition_session (id : int = 0 ):
5656 return datatypes .PartitionSession (
57- id = 0 ,
57+ id = id ,
5858 state = datatypes .PartitionSession .State .Active ,
5959 topic_path = "asd" ,
6060 partition_id = 1 ,
@@ -212,21 +212,27 @@ def create_message(
212212 _commit_end_offset = partition_session ._next_message_start_commit_offset + offset_delta ,
213213 )
214214
215- async def send_message (self , stream_reader , message : PublicMessage ):
216- await self .send_batch (stream_reader , [message ])
215+ async def send_message (self , stream_reader , message : PublicMessage , new_batch = True ):
216+ await self .send_batch (stream_reader , [message ], new_batch = new_batch )
217217
218- async def send_batch (self , stream_reader , batch : typing .List [PublicMessage ]):
218+ async def send_batch (self , stream_reader , batch : typing .List [PublicMessage ], new_batch = True ):
219219 if len (batch ) == 0 :
220220 return
221221
222222 first_message = batch [0 ]
223223 for message in batch :
224224 assert message ._partition_session is first_message ._partition_session
225225
226+ partition_session_id = first_message ._partition_session .id
227+
226228 def batch_count ():
227229 return len (stream_reader ._message_batches )
228230
231+ def batch_size ():
232+ return len (stream_reader ._message_batches [partition_session_id ].messages )
233+
229234 initial_batches = batch_count ()
235+ initial_batch_size = batch_size () if not new_batch else 0
230236
231237 stream = stream_reader ._stream # type: StreamMock
232238 stream .from_server .put_nowait (
@@ -261,7 +267,10 @@ def batch_count():
261267 ),
262268 )
263269 )
264- await wait_condition (lambda : batch_count () > initial_batches )
270+ if new_batch :
271+ await wait_condition (lambda : batch_count () > initial_batches )
272+ else :
273+ await wait_condition (lambda : batch_size () > initial_batch_size )
265274
266275 async def test_unknown_error (self , stream , stream_reader_finish_with_error ):
267276 class TestError (Exception ):
@@ -412,15 +421,11 @@ async def test_commit_ranges_for_received_messages(
412421 m2 ._commit_start_offset = m1 .offset + 1
413422
414423 await self .send_message (stream_reader_started , m1 )
415- await self .send_message (stream_reader_started , m2 )
416-
417- await stream_reader_started .wait_messages ()
418- received = stream_reader_started .receive_batch_nowait ().messages
419- assert received == [m1 ]
424+ await self .send_message (stream_reader_started , m2 , new_batch = False )
420425
421426 await stream_reader_started .wait_messages ()
422427 received = stream_reader_started .receive_batch_nowait ().messages
423- assert received == [m2 ]
428+ assert received == [m1 , m2 ]
424429
425430 await stream_reader_started .close (False )
426431
@@ -860,7 +865,7 @@ def reader_batch_count():
860865
861866 assert stream_reader ._buffer_size_bytes == initial_buffer_size - bytes_size
862867
863- last_batch = stream_reader ._message_batches [ - 1 ]
868+ _ , last_batch = stream_reader ._message_batches . popitem ()
864869 assert last_batch == PublicBatch (
865870 messages = [
866871 PublicMessage (
@@ -1059,74 +1064,74 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti
10591064 @pytest .mark .parametrize (
10601065 "batches_before,expected_message,batches_after" ,
10611066 [
1062- ([] , None , [] ),
1067+ ({} , None , {} ),
10631068 (
1064- [
1065- PublicBatch (
1069+ {
1070+ 0 : PublicBatch (
10661071 messages = [stub_message (1 )],
10671072 _partition_session = stub_partition_session (),
10681073 _bytes_size = 0 ,
10691074 _codec = Codec .CODEC_RAW ,
10701075 )
1071- ] ,
1076+ } ,
10721077 stub_message (1 ),
1073- [] ,
1078+ {} ,
10741079 ),
10751080 (
1076- [
1077- PublicBatch (
1081+ {
1082+ 0 : PublicBatch (
10781083 messages = [stub_message (1 ), stub_message (2 )],
10791084 _partition_session = stub_partition_session (),
10801085 _bytes_size = 0 ,
10811086 _codec = Codec .CODEC_RAW ,
10821087 ),
1083- PublicBatch (
1088+ 1 : PublicBatch (
10841089 messages = [stub_message (3 ), stub_message (4 )],
1085- _partition_session = stub_partition_session (),
1090+ _partition_session = stub_partition_session (1 ),
10861091 _bytes_size = 0 ,
10871092 _codec = Codec .CODEC_RAW ,
10881093 ),
1089- ] ,
1094+ } ,
10901095 stub_message (1 ),
1091- [
1092- PublicBatch (
1096+ {
1097+ 0 : PublicBatch (
10931098 messages = [stub_message (2 )],
10941099 _partition_session = stub_partition_session (),
10951100 _bytes_size = 0 ,
10961101 _codec = Codec .CODEC_RAW ,
10971102 ),
1098- PublicBatch (
1103+ 1 : PublicBatch (
10991104 messages = [stub_message (3 ), stub_message (4 )],
1100- _partition_session = stub_partition_session (),
1105+ _partition_session = stub_partition_session (1 ),
11011106 _bytes_size = 0 ,
11021107 _codec = Codec .CODEC_RAW ,
11031108 ),
1104- ] ,
1109+ } ,
11051110 ),
11061111 (
1107- [
1108- PublicBatch (
1112+ {
1113+ 0 : PublicBatch (
11091114 messages = [stub_message (1 )],
11101115 _partition_session = stub_partition_session (),
11111116 _bytes_size = 0 ,
11121117 _codec = Codec .CODEC_RAW ,
11131118 ),
1114- PublicBatch (
1119+ 1 : PublicBatch (
11151120 messages = [stub_message (2 ), stub_message (3 )],
1116- _partition_session = stub_partition_session (),
1121+ _partition_session = stub_partition_session (1 ),
11171122 _bytes_size = 0 ,
11181123 _codec = Codec .CODEC_RAW ,
11191124 ),
1120- ] ,
1125+ } ,
11211126 stub_message (1 ),
1122- [
1123- PublicBatch (
1127+ {
1128+ 1 : PublicBatch (
11241129 messages = [stub_message (2 ), stub_message (3 )],
1125- _partition_session = stub_partition_session (),
1130+ _partition_session = stub_partition_session (1 ),
11261131 _bytes_size = 0 ,
11271132 _codec = Codec .CODEC_RAW ,
11281133 )
1129- ] ,
1134+ } ,
11301135 ),
11311136 ],
11321137 )
@@ -1137,11 +1142,11 @@ async def test_read_message(
11371142 expected_message : PublicMessage ,
11381143 batches_after : typing .List [datatypes .PublicBatch ],
11391144 ):
1140- stream_reader ._message_batches = deque (batches_before )
1145+ stream_reader ._message_batches = OrderedDict (batches_before )
11411146 mess = stream_reader .receive_message_nowait ()
11421147
11431148 assert mess == expected_message
1144- assert list (stream_reader ._message_batches ) == batches_after
1149+ assert dict (stream_reader ._message_batches ) == batches_after
11451150
11461151 async def test_receive_batch_nowait (self , stream , stream_reader , partition_session ):
11471152 assert stream_reader .receive_batch_nowait () is None
@@ -1152,30 +1157,21 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi
11521157 await self .send_message (stream_reader , mess1 )
11531158
11541159 mess2 = self .create_message (partition_session , 2 , 1 )
1155- await self .send_message (stream_reader , mess2 )
1160+ await self .send_message (stream_reader , mess2 , new_batch = False )
11561161
11571162 assert stream_reader ._buffer_size_bytes == initial_buffer_size - 2 * self .default_batch_size
11581163
11591164 received = stream_reader .receive_batch_nowait ()
11601165 assert received == PublicBatch (
1161- messages = [mess1 ],
1166+ messages = [mess1 , mess2 ],
11621167 _partition_session = mess1 ._partition_session ,
1163- _bytes_size = self .default_batch_size ,
1164- _codec = Codec .CODEC_RAW ,
1165- )
1166-
1167- received = stream_reader .receive_batch_nowait ()
1168- assert received == PublicBatch (
1169- messages = [mess2 ],
1170- _partition_session = mess2 ._partition_session ,
1171- _bytes_size = self .default_batch_size ,
1168+ _bytes_size = self .default_batch_size * 2 ,
11721169 _codec = Codec .CODEC_RAW ,
11731170 )
11741171
11751172 assert stream_reader ._buffer_size_bytes == initial_buffer_size
11761173
1177- assert StreamReadMessage .ReadRequest (self .default_batch_size ) == stream .from_client .get_nowait ().client_message
1178- assert StreamReadMessage .ReadRequest (self .default_batch_size ) == stream .from_client .get_nowait ().client_message
1174+ assert StreamReadMessage .ReadRequest (self .default_batch_size * 2 ) == stream .from_client .get_nowait ().client_message
11791175
11801176 with pytest .raises (asyncio .QueueEmpty ):
11811177 stream .from_client .get_nowait ()
@@ -1186,13 +1182,18 @@ async def test_receive_message_nowait(self, stream, stream_reader, partition_ses
11861182 initial_buffer_size = stream_reader ._buffer_size_bytes
11871183
11881184 await self .send_batch (
1189- stream_reader , [self .create_message (partition_session , 1 , 1 ), self .create_message (partition_session , 2 , 1 )]
1185+ stream_reader ,
1186+ [
1187+ self .create_message (partition_session , 1 , 1 ),
1188+ self .create_message (partition_session , 2 , 1 ),
1189+ ],
11901190 )
11911191 await self .send_batch (
11921192 stream_reader ,
11931193 [
11941194 self .create_message (partition_session , 10 , 1 ),
11951195 ],
1196+ new_batch = False ,
11961197 )
11971198
11981199 assert stream_reader ._buffer_size_bytes == initial_buffer_size - 2 * self .default_batch_size
0 commit comments