Skip to content

Commit 5716653

Browse files
Add lock on async cursor execute
1 parent 4f129aa commit 5716653

File tree

2 files changed

+82
-3
lines changed

2 files changed

+82
-3
lines changed

src/snowflake/connector/aio/_cursor.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from snowflake.connector.cursor import SnowflakeCursorBase as SnowflakeCursorBaseSync
4848
from snowflake.connector.cursor import T
4949
from snowflake.connector.errorcode import (
50+
ER_CURSOR_EXECUTE_IN_PROGRESS,
5051
ER_CURSOR_IS_CLOSED,
5152
ER_FAILED_PROCESSING_PYFORMAT,
5253
ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT,
@@ -85,6 +86,7 @@ def __init__(
8586
self._lock_canceling = asyncio.Lock()
8687
self._timebomb: asyncio.Task | None = None
8788
self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None
89+
self._executing: bool = False
8890

8991
def __aiter__(self):
9092
return self
@@ -552,16 +554,92 @@ async def execute(
552554
_force_qmark_paramstyle: bool = False,
553555
_dataframe_ast: str | None = None,
554556
) -> Self | dict[str, Any] | None:
555-
if _exec_async:
556-
_no_results = True
557-
logger.debug("executing SQL/command")
558557
if self.is_closed():
559558
Error.errorhandler_wrapper(
560559
self.connection,
561560
self,
562561
InterfaceError,
563562
{"msg": "Cursor is closed in execute.", "errno": ER_CURSOR_IS_CLOSED},
564563
)
564+
if self._executing:
565+
Error.errorhandler_wrapper(
566+
self.connection,
567+
self,
568+
InterfaceError,
569+
{
570+
"msg": "Another execute is already in progress on this cursor. "
571+
"Async cursors are not safe for concurrent use by multiple coroutines. "
572+
"Use a separate cursor for each concurrent operation.",
573+
"errno": ER_CURSOR_EXECUTE_IN_PROGRESS,
574+
},
575+
)
576+
577+
self._executing = True
578+
try:
579+
return await self._execute_impl(
580+
command=command,
581+
params=params,
582+
_bind_stage=_bind_stage,
583+
timeout=timeout,
584+
_exec_async=_exec_async,
585+
_no_retry=_no_retry,
586+
_do_reset=_do_reset,
587+
_put_callback=_put_callback,
588+
_put_azure_callback=_put_azure_callback,
589+
_put_callback_output_stream=_put_callback_output_stream,
590+
_get_callback=_get_callback,
591+
_get_azure_callback=_get_azure_callback,
592+
_get_callback_output_stream=_get_callback_output_stream,
593+
_show_progress_bar=_show_progress_bar,
594+
_statement_params=_statement_params,
595+
_is_internal=_is_internal,
596+
_describe_only=_describe_only,
597+
_no_results=_no_results,
598+
_is_put_get=_is_put_get,
599+
_raise_put_get_error=_raise_put_get_error,
600+
_force_put_overwrite=_force_put_overwrite,
601+
_skip_upload_on_content_match=_skip_upload_on_content_match,
602+
file_stream=file_stream,
603+
num_statements=num_statements,
604+
_force_qmark_paramstyle=_force_qmark_paramstyle,
605+
_dataframe_ast=_dataframe_ast,
606+
)
607+
finally:
608+
self._executing = False
609+
610+
async def _execute_impl(
611+
self,
612+
command: str,
613+
params: Sequence[Any] | dict[Any, Any] | None = None,
614+
_bind_stage: str | None = None,
615+
timeout: int | None = None,
616+
_exec_async: bool = False,
617+
_no_retry: bool = False,
618+
_do_reset: bool = True,
619+
_put_callback: SnowflakeProgressPercentage = None,
620+
_put_azure_callback: SnowflakeProgressPercentage = None,
621+
_put_callback_output_stream: IO[str] = sys.stdout,
622+
_get_callback: SnowflakeProgressPercentage = None,
623+
_get_azure_callback: SnowflakeProgressPercentage = None,
624+
_get_callback_output_stream: IO[str] = sys.stdout,
625+
_show_progress_bar: bool = True,
626+
_statement_params: dict[str, str] | None = None,
627+
_is_internal: bool = False,
628+
_describe_only: bool = False,
629+
_no_results: bool = False,
630+
_is_put_get: bool | None = None,
631+
_raise_put_get_error: bool = True,
632+
_force_put_overwrite: bool = False,
633+
_skip_upload_on_content_match: bool = False,
634+
file_stream: IO[bytes] | None = None,
635+
num_statements: int | None = None,
636+
_force_qmark_paramstyle: bool = False,
637+
_dataframe_ast: str | None = None,
638+
) -> Self | dict[str, Any] | None:
639+
"""Internal implementation of execute, called after concurrency checks."""
640+
if _exec_async:
641+
_no_results = True
642+
logger.debug("executing SQL/command")
565643

566644
if _do_reset:
567645
self.reset()

src/snowflake/connector/errorcode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
ER_CHUNK_DOWNLOAD_FAILED = 252010
5151
ER_NOT_IMPLICITY_SNOWFLAKE_DATATYPE = 252011
5252
ER_FAILED_PROCESSING_QMARK = 252012
53+
ER_CURSOR_EXECUTE_IN_PROGRESS = 252013
5354

5455
# file_transfer
5556
ER_INVALID_STAGE_FS = 253001

0 commit comments

Comments
 (0)