@@ -1148,6 +1148,96 @@ async def test_read_message(
11481148 assert mess == expected_message
11491149 assert dict (stream_reader ._message_batches ) == batches_after
11501150
1151+ @pytest .mark .parametrize (
1152+ "batches_before,max_messages,actual_messages,batches_after" ,
1153+ [
1154+ (
1155+ {
1156+ 0 : PublicBatch (
1157+ messages = [stub_message (1 )],
1158+ _partition_session = stub_partition_session (),
1159+ _bytes_size = 0 ,
1160+ _codec = Codec .CODEC_RAW ,
1161+ )
1162+ },
1163+ None ,
1164+ 1 ,
1165+ {},
1166+ ),
1167+ (
1168+ {
1169+ 0 : PublicBatch (
1170+ messages = [stub_message (1 ), stub_message (2 )],
1171+ _partition_session = stub_partition_session (),
1172+ _bytes_size = 0 ,
1173+ _codec = Codec .CODEC_RAW ,
1174+ ),
1175+ 1 : PublicBatch (
1176+ messages = [stub_message (3 ), stub_message (4 )],
1177+ _partition_session = stub_partition_session (1 ),
1178+ _bytes_size = 0 ,
1179+ _codec = Codec .CODEC_RAW ,
1180+ ),
1181+ },
1182+ 1 ,
1183+ 1 ,
1184+ {
1185+ 1 : PublicBatch (
1186+ messages = [stub_message (3 ), stub_message (4 )],
1187+ _partition_session = stub_partition_session (1 ),
1188+ _bytes_size = 0 ,
1189+ _codec = Codec .CODEC_RAW ,
1190+ ),
1191+ 0 : PublicBatch (
1192+ messages = [stub_message (2 )],
1193+ _partition_session = stub_partition_session (),
1194+ _bytes_size = 0 ,
1195+ _codec = Codec .CODEC_RAW ,
1196+ ),
1197+ },
1198+ ),
1199+ (
1200+ {
1201+ 0 : PublicBatch (
1202+ messages = [stub_message (1 )],
1203+ _partition_session = stub_partition_session (),
1204+ _bytes_size = 0 ,
1205+ _codec = Codec .CODEC_RAW ,
1206+ ),
1207+ 1 : PublicBatch (
1208+ messages = [stub_message (2 ), stub_message (3 )],
1209+ _partition_session = stub_partition_session (1 ),
1210+ _bytes_size = 0 ,
1211+ _codec = Codec .CODEC_RAW ,
1212+ ),
1213+ },
1214+ 100 ,
1215+ 1 ,
1216+ {
1217+ 1 : PublicBatch (
1218+ messages = [stub_message (2 ), stub_message (3 )],
1219+ _partition_session = stub_partition_session (1 ),
1220+ _bytes_size = 0 ,
1221+ _codec = Codec .CODEC_RAW ,
1222+ )
1223+ },
1224+ ),
1225+ ],
1226+ )
1227+ async def test_read_batch_max_messages (
1228+ self ,
1229+ stream_reader ,
1230+ batches_before : typing .List [datatypes .PublicBatch ],
1231+ max_messages : typing .Optional [int ],
1232+ actual_messages : int ,
1233+ batches_after : typing .List [datatypes .PublicBatch ],
1234+ ):
1235+ stream_reader ._message_batches = OrderedDict (batches_before )
1236+ batch = stream_reader .receive_batch_nowait (max_messages = max_messages )
1237+
1238+ assert len (batch .messages ) == actual_messages
1239+ assert stream_reader ._message_batches == OrderedDict (batches_after )
1240+
11511241 async def test_receive_batch_nowait (self , stream , stream_reader , partition_session ):
11521242 assert stream_reader .receive_batch_nowait () is None
11531243
0 commit comments