Skip to content

Commit fb08011

Browse files
authored
SNOW-2437690 make sproc cancel API work with public connector (#3914)
1 parent 01dead4 commit fb08011

File tree

3 files changed

+164
-7
lines changed

3 files changed

+164
-7
lines changed

src/snowflake/snowpark/async_job.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,34 @@ def cancel(self) -> None:
284284
"ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS", False
285285
)
286286
):
287-
cancel_resp = self._session._conn._conn.cancel_query(self.query_id)
288-
if not cancel_resp.get("success", False):
287+
import _snowflake
288+
import json
289+
import uuid
290+
291+
try:
292+
uuid.UUID(self.query_id)
293+
except ValueError:
294+
raise ValueError(f"Invalid UUID: '{self.query_id}'")
295+
296+
raw_cancel_resp = _snowflake.cancel_query(self.query_id)
297+
298+
# Set failure_response when
299+
# - success != True in the response or
300+
# - cannot parse the response at all.
301+
failure_response = None
302+
try:
303+
parsed_cancel_resp = json.loads(raw_cancel_resp)
304+
if not parsed_cancel_resp.get("success", False):
305+
failure_response = parsed_cancel_resp
306+
except (TypeError, json.JSONDecodeError) as e:
307+
failure_response = {
308+
"success": False,
309+
"error": f"Error parsing response: {e}",
310+
}
311+
312+
if failure_response:
289313
raise DatabaseError(
290-
f"Failed to cancel query. Returned response: {cancel_resp}"
314+
f"Failed to cancel query. Returned response: {failure_response}"
291315
)
292316
else:
293317
self._cursor.execute(f"select SYSTEM$CANCEL_QUERY('{self.query_id}')")

tests/integ/scala/test_async_job_suite.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,10 +369,6 @@ def test_async_batch_insert(session):
369369
analyzer.ARRAY_BIND_THRESHOLD = original_value
370370

371371

372-
@pytest.mark.skipif(
373-
IS_IN_STORED_PROC,
374-
reason="TODO(SNOW-932722): Cancel query is not allowed in stored proc",
375-
)
376372
def test_async_is_running_and_cancel(session):
377373
async_job = session.sql("select SYSTEM$WAIT(3)").collect_nowait()
378374
while not async_job.is_done():

tests/unit/test_async_job.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#
2+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
import sys
6+
import uuid
7+
from types import SimpleNamespace
8+
from unittest import mock
9+
from unittest.mock import MagicMock
10+
from enum import Enum
11+
12+
import pytest
13+
14+
from snowflake.connector.errors import DatabaseError
15+
from snowflake.snowpark.async_job import AsyncJob
16+
from snowflake.snowpark.session import Session
17+
18+
19+
class CancelQueryErrorType(Enum):
20+
"""This is test-only enum to indicate what error we should mock for."""
21+
22+
VALID_JSON_BUT_FAIL = 0
23+
INVALID_JSON = 1
24+
25+
26+
def _make_session_with_cursor() -> tuple[Session, mock.MagicMock]:
27+
server_conn = MagicMock()
28+
inner_conn = mock.MagicMock()
29+
mock_cursor = mock.MagicMock()
30+
inner_conn.cursor.return_value = mock_cursor
31+
server_conn._conn = inner_conn
32+
33+
# Return True for async feature flag.
34+
def _param_side_effect(name, default_value):
35+
if name and name.upper() == "ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS":
36+
return True
37+
return default_value
38+
39+
server_conn._get_client_side_session_parameter.side_effect = _param_side_effect
40+
session = Session(server_conn)
41+
return session, mock_cursor
42+
43+
44+
def test_async_job_cancel_executes_sys_func_in_regular_client(monkeypatch):
45+
session, mock_cursor = _make_session_with_cursor()
46+
47+
# Test with regular client / non-sproc path.
48+
monkeypatch.setattr(
49+
"snowflake.snowpark.async_job.is_in_stored_procedure", lambda: False
50+
)
51+
52+
qid = str(uuid.uuid4())
53+
54+
job = AsyncJob(query_id=qid, query=None, session=session)
55+
job.cancel()
56+
57+
mock_cursor.execute.assert_called_once_with(f"select SYSTEM$CANCEL_QUERY('{qid}')")
58+
59+
60+
def test_async_job_cancel_in_sproc_success(monkeypatch):
61+
session, mock_cursor = _make_session_with_cursor()
62+
63+
# Test with sproc path.
64+
monkeypatch.setattr(
65+
"snowflake.snowpark.async_job.is_in_stored_procedure", lambda: True
66+
)
67+
68+
# Inject a fake _snowflake module with a mock cancel_query.
69+
cancel_mock = mock.MagicMock(return_value='{"success": true}')
70+
fake_mod = SimpleNamespace(cancel_query=cancel_mock)
71+
monkeypatch.setitem(sys.modules, "_snowflake", fake_mod)
72+
73+
qid = str(uuid.uuid4())
74+
job = AsyncJob(query_id=qid, query=None, session=session)
75+
job.cancel() # should not raise
76+
77+
# Verify that _snowflake.cancel_query is called with given query id.
78+
cancel_mock.assert_called_once_with(qid)
79+
# Verify that the system function is not called.
80+
mock_cursor.execute.assert_not_called()
81+
82+
83+
@pytest.mark.parametrize(
84+
"fake_response_and_type",
85+
[
86+
# Valid JSON, but success=false.
87+
(
88+
'{"success": false, "error": "boom"}',
89+
CancelQueryErrorType.VALID_JSON_BUT_FAIL,
90+
),
91+
# Invalid JSON.
92+
("{", CancelQueryErrorType.INVALID_JSON),
93+
],
94+
)
95+
def test_async_job_cancel_in_sproc_failure_in_cancel_query(
96+
monkeypatch, fake_response_and_type
97+
):
98+
session, _ = _make_session_with_cursor()
99+
100+
# Test with sproc path.
101+
monkeypatch.setattr(
102+
"snowflake.snowpark.async_job.is_in_stored_procedure", lambda: True
103+
)
104+
105+
# Mock _snowflake.cancel_query to return error response as instructed to.
106+
fake_response, error_type = fake_response_and_type
107+
fake_mod = SimpleNamespace(cancel_query=lambda qid: fake_response)
108+
monkeypatch.setitem(sys.modules, "_snowflake", fake_mod)
109+
110+
job = AsyncJob(query_id=str(uuid.uuid4()), query=None, session=session)
111+
with pytest.raises(DatabaseError, match="Failed to cancel query") as exc_info:
112+
job.cancel()
113+
if error_type == CancelQueryErrorType.VALID_JSON_BUT_FAIL:
114+
assert "Error parsing response" not in str(exc_info.value)
115+
elif error_type == CancelQueryErrorType.INVALID_JSON:
116+
assert "Error parsing response" in str(exc_info.value)
117+
else:
118+
raise ValueError(f"Invalid test case: {fake_response_and_type}")
119+
120+
121+
def test_async_job_cancel_in_sproc_failure_in_uuid_validation(monkeypatch):
122+
session, _ = _make_session_with_cursor()
123+
124+
# Test with sproc path.
125+
monkeypatch.setattr(
126+
"snowflake.snowpark.async_job.is_in_stored_procedure", lambda: True
127+
)
128+
129+
# Inject a fake _snowflake module with a mock cancel_query.
130+
cancel_mock = mock.MagicMock(return_value='{"success": true}')
131+
fake_mod = SimpleNamespace(cancel_query=cancel_mock)
132+
monkeypatch.setitem(sys.modules, "_snowflake", fake_mod)
133+
134+
qid = "qid-invalid-123"
135+
job = AsyncJob(query_id=qid, query=None, session=session)
136+
with pytest.raises(ValueError, match=f"Invalid UUID: '{qid}'"):
137+
job.cancel()

0 commit comments

Comments
 (0)