Skip to content

Commit f0ed2d4

Browse files
fix sql_executor
1 parent 3ac92a8 commit f0ed2d4

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/snowflake/cli/api/sql_execution.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,10 @@ def use(self, object_type: ObjectType, name: str):
126126
# Rewrite the error to make the message more useful.
127127
raise CouldNotUseObjectError(object_type=object_type, name=name) from err
128128

129-
def current_role(self) -> str:
130-
return to_identifier(self.execute_query(f"select current_role()").fetchone()[0])
129+
def current_role(self) -> Optional[str]:
130+
if result := self.execute_query(f"select current_role()").fetchone()[0]:
131+
return to_identifier(result)
132+
return None
131133

132134
@contextmanager
133135
def use_role(self, new_role: str):
@@ -137,14 +139,17 @@ def use_role(self, new_role: str):
137139
"""
138140
new_role = to_identifier(new_role)
139141
prev_role = self.current_role()
140-
is_different_role = new_role.lower() != prev_role.lower()
142+
if prev_role:
143+
is_different_role = new_role.lower() != prev_role.lower()
144+
else:
145+
is_different_role = True
141146
if is_different_role:
142147
self._log.debug("Assuming different role: %s", new_role)
143148
self.execute_query(f"use role {new_role}")
144149
try:
145150
yield
146151
finally:
147-
if is_different_role:
152+
if is_different_role and prev_role:
148153
self.execute_query(f"use role {prev_role}")
149154

150155
def session_has_warehouse(self) -> bool:

tests/api/test_sql_execution.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ def test_use_role_same_id(mock_execute_query, mock_cursor, new_role, current_rol
6363
assert mock_execute_query.mock_calls == [mock.call("select current_role()")]
6464

6565

66+
@mock.patch(EXECUTE_QUERY)
67+
def test_use_role_no_current_role(mock_execute_query, mock_cursor):
68+
mock_execute_query.return_value = mock_cursor([(None,)], [])
69+
with SqlExecutor().use_role("new_role"):
70+
pass
71+
assert mock_execute_query.mock_calls == [
72+
mock.call("select current_role()"),
73+
mock.call(f"use role new_role"),
74+
]
75+
76+
6677
@pytest.mark.parametrize(
6778
"new_warehouse, expected_new_warehouse, current_warehouse, expected_current_warehouse",
6879
[
@@ -94,6 +105,20 @@ def test_use_warehouse_different_id(
94105
]
95106

96107

108+
@mock.patch(EXECUTE_QUERY)
109+
def test_use_warehouse_no_current_wh(
110+
mock_execute_query,
111+
mock_cursor,
112+
):
113+
mock_execute_query.return_value = mock_cursor([(None,)], [])
114+
with SqlExecutor().use_warehouse("new_warehouse"):
115+
pass
116+
assert mock_execute_query.mock_calls == [
117+
mock.call("select current_warehouse()"),
118+
mock.call(f"use warehouse new_warehouse"),
119+
]
120+
121+
97122
@pytest.mark.parametrize(
98123
"new_warehouse, current_warehouse",
99124
[

0 commit comments

Comments
 (0)