@@ -99,6 +99,7 @@ async def wait_message(self):
9999
100100 async def receive_batch (
101101 self ,
102+ max_messages : typing .Union [int , None ] = None ,
102103 ) -> typing .Union [datatypes .PublicBatch , None ]:
103104 """
104105 Get one messages batch from reader.
@@ -107,7 +108,9 @@ async def receive_batch(
107108 use asyncio.wait_for for wait with timeout.
108109 """
109110 await self ._reconnector .wait_message ()
110- return self ._reconnector .receive_batch_nowait ()
111+ return self ._reconnector .receive_batch_nowait (
112+ max_messages = max_messages ,
113+ )
111114
112115 async def receive_message (self ) -> typing .Optional [datatypes .PublicMessage ]:
113116 """
@@ -214,8 +217,10 @@ async def wait_message(self):
214217 await self ._state_changed .wait ()
215218 self ._state_changed .clear ()
216219
217- def receive_batch_nowait (self ):
218- return self ._stream_reader .receive_batch_nowait ()
220+ def receive_batch_nowait (self , max_messages : Optional [int ] = None ):
221+ return self ._stream_reader .receive_batch_nowait (
222+ max_messages = max_messages ,
223+ )
219224
220225 def receive_message_nowait (self ):
221226 return self ._stream_reader .receive_message_nowait ()
@@ -383,17 +388,44 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
383388 partition_session_id , batch = self ._message_batches .popitem (last = False )
384389 return partition_session_id , batch
385390
386- def receive_batch_nowait (self ):
391+ def _cut_batch_by_max_messages (
392+ batch : datatypes .PublicBatch ,
393+ max_messages : int ,
394+ ) -> typing .Tuple [datatypes .PublicBatch , datatypes .PublicBatch ]:
395+ initial_length = len (batch .messages )
396+ one_message_size = batch ._bytes_size // initial_length
397+
398+ new_batch = datatypes .PublicBatch (
399+ messages = batch .messages [:max_messages ],
400+ _partition_session = batch ._partition_session ,
401+ _bytes_size = one_message_size * max_messages ,
402+ _codec = batch ._codec ,
403+ )
404+
405+ batch .messages = batch .messages [max_messages :]
406+ batch ._bytes_size = one_message_size * (initial_length - max_messages )
407+
408+ return new_batch , batch
409+
410+ def receive_batch_nowait (self , max_messages : Optional [int ] = None ):
387411 if self ._get_first_error ():
388412 raise self ._get_first_error ()
389413
390414 if not self ._message_batches :
391415 return None
392416
393- _ , batch = self ._get_first_batch ()
394- self ._buffer_release_bytes (batch ._bytes_size )
417+ part_sess_id , batch = self ._get_first_batch ()
418+
419+ if max_messages is None or len (batch .messages ) <= max_messages :
420+ self ._buffer_release_bytes (batch ._bytes_size )
421+ return batch
422+
423+ cutted_batch , remaining_batch = self ._cut_batch_by_max_messages (batch , max_messages )
424+
425+ self ._message_batches [part_sess_id ] = remaining_batch
426+ self ._buffer_release_bytes (cutted_batch ._bytes_size )
395427
396- return batch
428+ return cutted_batch
397429
398430 def receive_message_nowait (self ):
399431 if self ._get_first_error ():
0 commit comments