|
3 | 3 | import asyncio |
4 | 4 | import concurrent.futures |
5 | 5 | import gzip |
| 6 | +import random |
6 | 7 | import typing |
7 | 8 | from asyncio import Task |
8 | | -from collections import deque |
9 | 9 | from typing import Optional, Set, Dict, Union, Callable |
10 | 10 |
|
11 | 11 | import ydb |
@@ -264,7 +264,7 @@ class ReaderStream: |
264 | 264 |
|
265 | 265 | _state_changed: asyncio.Event |
266 | 266 | _closed: bool |
267 | | - _message_batches: typing.Deque[datatypes.PublicBatch] |
| 267 | + _message_batches: typing.Dict[int, datatypes.PublicBatch] |
268 | 268 | _first_error: asyncio.Future[YdbError] |
269 | 269 |
|
270 | 270 | _update_token_interval: Union[int, float] |
@@ -296,7 +296,7 @@ def __init__( |
296 | 296 | self._closed = False |
297 | 297 | self._first_error = asyncio.get_running_loop().create_future() |
298 | 298 | self._batches_to_decode = asyncio.Queue() |
299 | | - self._message_batches = deque() |
| 299 | + self._message_batches = dict() |
300 | 300 |
|
301 | 301 | self._update_token_interval = settings.update_token_interval |
302 | 302 | self._get_token_function = get_token_function |
@@ -359,29 +359,36 @@ async def wait_messages(self): |
359 | 359 | await self._state_changed.wait() |
360 | 360 | self._state_changed.clear() |
361 | 361 |
|
| 362 | + def _get_random_batch(self): |
| 363 | + rnd_id = random.choice(list(self._message_batches.keys())) |
| 364 | + return rnd_id, self._message_batches[rnd_id] |
| 365 | + |
362 | 366 | def receive_batch_nowait(self): |
363 | 367 | if self._get_first_error(): |
364 | 368 | raise self._get_first_error() |
365 | 369 |
|
366 | 370 | if not self._message_batches: |
367 | 371 | return None |
368 | 372 |
|
369 | | - batch = self._message_batches.popleft() |
| 373 | + part_sess_id, batch = self._get_random_batch() |
370 | 374 | self._buffer_release_bytes(batch._bytes_size) |
| 375 | + del self._message_batches[part_sess_id] |
| 376 | + |
371 | 377 | return batch |
372 | 378 |
|
373 | 379 | def receive_message_nowait(self): |
374 | 380 | if self._get_first_error(): |
375 | 381 | raise self._get_first_error() |
376 | 382 |
|
377 | | - try: |
378 | | - batch = self._message_batches[0] |
379 | | - message = batch.pop_message() |
380 | | - except IndexError: |
| 383 | + if not self._message_batches: |
381 | 384 | return None |
382 | 385 |
|
383 | | - if batch.empty(): |
384 | | - self.receive_batch_nowait() |
| 386 | + part_sess_id, batch = self._get_random_batch() |
| 387 | + |
| 388 | + message = batch.messages.pop(0) |
| 389 | + if len(batch.messages) == 0: |
| 390 | + self._buffer_release_bytes(batch._bytes_size) |
| 391 | + del self._message_batches[part_sess_id] |
385 | 392 |
|
386 | 393 | return message |
387 | 394 |
|
@@ -605,9 +612,18 @@ async def _decode_batches_loop(self): |
605 | 612 | while True: |
606 | 613 | batch = await self._batches_to_decode.get() |
607 | 614 | await self._decode_batch_inplace(batch) |
608 | | - self._message_batches.append(batch) |
| 615 | + self._add_batch_to_queue(batch) |
609 | 616 | self._state_changed.set() |
610 | 617 |
|
| 618 | + def _add_batch_to_queue(self, batch: datatypes.PublicBatch): |
| 619 | + part_sess_id = batch._partition_session.id |
| 620 | + if part_sess_id in self._message_batches: |
| 621 | + self._message_batches[part_sess_id].messages.extend(batch.messages) |
| 622 | + self._message_batches[part_sess_id]._bytes_size += batch._bytes_size |
| 623 | + return |
| 624 | + |
| 625 | + self._message_batches[part_sess_id] = batch |
| 626 | + |
611 | 627 | async def _decode_batch_inplace(self, batch): |
612 | 628 | if batch._codec == Codec.CODEC_RAW: |
613 | 629 | return |
|
0 commit comments