@@ -97,6 +97,7 @@ async def wait_message(self):
9797
9898 async def receive_batch (
9999 self ,
100+ max_messages : typing .Union [int , None ] = None ,
100101 ) -> typing .Union [datatypes .PublicBatch , None ]:
101102 """
102103 Get one messages batch from reader.
@@ -105,7 +106,9 @@ async def receive_batch(
105106 use asyncio.wait_for for wait with timeout.
106107 """
107108 await self ._reconnector .wait_message ()
108- return self ._reconnector .receive_batch_nowait ()
109+ return self ._reconnector .receive_batch_nowait (
110+ max_messages = max_messages ,
111+ )
109112
110113 async def receive_message (self ) -> typing .Optional [datatypes .PublicMessage ]:
111114 """
@@ -212,8 +215,10 @@ async def wait_message(self):
212215 await self ._state_changed .wait ()
213216 self ._state_changed .clear ()
214217
215- def receive_batch_nowait (self ):
216- return self ._stream_reader .receive_batch_nowait ()
218+ def receive_batch_nowait (self , max_messages : Optional [int ] = None ):
219+ return self ._stream_reader .receive_batch_nowait (
220+ max_messages = max_messages ,
221+ )
217222
218223 def receive_message_nowait (self ):
219224 return self ._stream_reader .receive_message_nowait ()
@@ -363,17 +368,44 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
363368 first_id , batch = self ._message_batches .popitem (last = False )
364369 return first_id , batch
365370
366- def receive_batch_nowait (self ):
371+ def _cut_batch_by_max_messages (
372+ batch : datatypes .PublicBatch ,
373+ max_messages : int ,
374+ ) -> typing .Tuple [datatypes .PublicBatch , datatypes .PublicBatch ]:
375+ initial_length = len (batch .messages )
376+ one_message_size = batch ._bytes_size // initial_length
377+
378+ new_batch = datatypes .PublicBatch (
379+ messages = batch .messages [:max_messages ],
380+ _partition_session = batch ._partition_session ,
381+ _bytes_size = one_message_size * max_messages ,
382+ _codec = batch ._codec ,
383+ )
384+
385+ batch .messages = batch .messages [max_messages :]
386+ batch ._bytes_size = one_message_size * (initial_length - max_messages )
387+
388+ return new_batch , batch
389+
390+ def receive_batch_nowait (self , max_messages : Optional [int ] = None ):
367391 if self ._get_first_error ():
368392 raise self ._get_first_error ()
369393
370394 if not self ._message_batches :
371395 return None
372396
373- _ , batch = self ._get_first_batch ()
374- self ._buffer_release_bytes (batch ._bytes_size )
397+ part_sess_id , batch = self ._get_first_batch ()
398+
399+ if max_messages is None or len (batch .messages ) <= max_messages :
400+ self ._buffer_release_bytes (batch ._bytes_size )
401+ return batch
402+
403+ cutted_batch , remaining_batch = self ._cut_batch_by_max_messages (batch , max_messages )
404+
405+ self ._message_batches [part_sess_id ] = remaining_batch
406+ self ._buffer_release_bytes (cutted_batch ._bytes_size )
375407
376- return batch
408+ return cutted_batch
377409
378410 def receive_message_nowait (self ):
379411 if self ._get_first_error ():
0 commit comments