|
47 | 47 | from snowflake.connector.cursor import SnowflakeCursorBase as SnowflakeCursorBaseSync |
48 | 48 | from snowflake.connector.cursor import T |
49 | 49 | from snowflake.connector.errorcode import ( |
| 50 | + ER_CURSOR_EXECUTE_IN_PROGRESS, |
50 | 51 | ER_CURSOR_IS_CLOSED, |
51 | 52 | ER_FAILED_PROCESSING_PYFORMAT, |
52 | 53 | ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT, |
@@ -85,6 +86,7 @@ def __init__( |
85 | 86 | self._lock_canceling = asyncio.Lock() |
86 | 87 | self._timebomb: asyncio.Task | None = None |
87 | 88 | self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None |
| 89 | + self._executing: bool = False |
88 | 90 |
|
89 | 91 | def __aiter__(self): |
90 | 92 | return self |
@@ -552,16 +554,92 @@ async def execute( |
552 | 554 | _force_qmark_paramstyle: bool = False, |
553 | 555 | _dataframe_ast: str | None = None, |
554 | 556 | ) -> Self | dict[str, Any] | None: |
555 | | - if _exec_async: |
556 | | - _no_results = True |
557 | | - logger.debug("executing SQL/command") |
558 | 557 | if self.is_closed(): |
559 | 558 | Error.errorhandler_wrapper( |
560 | 559 | self.connection, |
561 | 560 | self, |
562 | 561 | InterfaceError, |
563 | 562 | {"msg": "Cursor is closed in execute.", "errno": ER_CURSOR_IS_CLOSED}, |
564 | 563 | ) |
| 564 | + if self._executing: |
| 565 | + Error.errorhandler_wrapper( |
| 566 | + self.connection, |
| 567 | + self, |
| 568 | + InterfaceError, |
| 569 | + { |
| 570 | + "msg": "Another execute is already in progress on this cursor. " |
| 571 | + "Async cursors are not safe for concurrent use by multiple coroutines. " |
| 572 | + "Use a separate cursor for each concurrent operation.", |
| 573 | + "errno": ER_CURSOR_EXECUTE_IN_PROGRESS, |
| 574 | + }, |
| 575 | + ) |
| 576 | + |
| 577 | + self._executing = True |
| 578 | + try: |
| 579 | + return await self._execute_impl( |
| 580 | + command=command, |
| 581 | + params=params, |
| 582 | + _bind_stage=_bind_stage, |
| 583 | + timeout=timeout, |
| 584 | + _exec_async=_exec_async, |
| 585 | + _no_retry=_no_retry, |
| 586 | + _do_reset=_do_reset, |
| 587 | + _put_callback=_put_callback, |
| 588 | + _put_azure_callback=_put_azure_callback, |
| 589 | + _put_callback_output_stream=_put_callback_output_stream, |
| 590 | + _get_callback=_get_callback, |
| 591 | + _get_azure_callback=_get_azure_callback, |
| 592 | + _get_callback_output_stream=_get_callback_output_stream, |
| 593 | + _show_progress_bar=_show_progress_bar, |
| 594 | + _statement_params=_statement_params, |
| 595 | + _is_internal=_is_internal, |
| 596 | + _describe_only=_describe_only, |
| 597 | + _no_results=_no_results, |
| 598 | + _is_put_get=_is_put_get, |
| 599 | + _raise_put_get_error=_raise_put_get_error, |
| 600 | + _force_put_overwrite=_force_put_overwrite, |
| 601 | + _skip_upload_on_content_match=_skip_upload_on_content_match, |
| 602 | + file_stream=file_stream, |
| 603 | + num_statements=num_statements, |
| 604 | + _force_qmark_paramstyle=_force_qmark_paramstyle, |
| 605 | + _dataframe_ast=_dataframe_ast, |
| 606 | + ) |
| 607 | + finally: |
| 608 | + self._executing = False |
| 609 | + |
| 610 | + async def _execute_impl( |
| 611 | + self, |
| 612 | + command: str, |
| 613 | + params: Sequence[Any] | dict[Any, Any] | None = None, |
| 614 | + _bind_stage: str | None = None, |
| 615 | + timeout: int | None = None, |
| 616 | + _exec_async: bool = False, |
| 617 | + _no_retry: bool = False, |
| 618 | + _do_reset: bool = True, |
| 619 | + _put_callback: SnowflakeProgressPercentage = None, |
| 620 | + _put_azure_callback: SnowflakeProgressPercentage = None, |
| 621 | + _put_callback_output_stream: IO[str] = sys.stdout, |
| 622 | + _get_callback: SnowflakeProgressPercentage = None, |
| 623 | + _get_azure_callback: SnowflakeProgressPercentage = None, |
| 624 | + _get_callback_output_stream: IO[str] = sys.stdout, |
| 625 | + _show_progress_bar: bool = True, |
| 626 | + _statement_params: dict[str, str] | None = None, |
| 627 | + _is_internal: bool = False, |
| 628 | + _describe_only: bool = False, |
| 629 | + _no_results: bool = False, |
| 630 | + _is_put_get: bool | None = None, |
| 631 | + _raise_put_get_error: bool = True, |
| 632 | + _force_put_overwrite: bool = False, |
| 633 | + _skip_upload_on_content_match: bool = False, |
| 634 | + file_stream: IO[bytes] | None = None, |
| 635 | + num_statements: int | None = None, |
| 636 | + _force_qmark_paramstyle: bool = False, |
| 637 | + _dataframe_ast: str | None = None, |
| 638 | + ) -> Self | dict[str, Any] | None: |
| 639 | + """Internal implementation of execute, called after concurrency checks.""" |
| 640 | + if _exec_async: |
| 641 | + _no_results = True |
| 642 | + logger.debug("executing SQL/command") |
565 | 643 |
|
566 | 644 | if _do_reset: |
567 | 645 | self.reset() |
|
0 commit comments