Skip to content

Commit ee89c25

Browse files
Snow 2112354 fix empty current object (#2330)
* fix sql-facade * fix sql_executor
1 parent 94bf0f2 commit ee89c25

File tree

5 files changed

+93
-5
lines changed

5 files changed

+93
-5
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
* Added `--encryption` flag to `snow stage create` command defining the type of encryption for all files on the stage.
3232

3333
## Fixes and improvements
34+
* Fix `use` commands error if current database is not set.
3435

3536

3637
# v3.8.3

src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,10 @@ def _use_object_optional(self, object_type: UseObjectType, name: str | None):
159159
)
160160

161161
try:
162-
prev_obj = to_identifier(current_obj_result_row[0])
162+
if current_obj_result_row[0]:
163+
prev_obj = to_identifier(current_obj_result_row[0])
164+
else:
165+
prev_obj = None
163166
except IndexError:
164167
prev_obj = None
165168

src/snowflake/cli/api/sql_execution.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ def use(self, object_type: ObjectType, name: str):
124124
# Rewrite the error to make the message more useful.
125125
raise CouldNotUseObjectError(object_type=object_type, name=name) from err
126126

127-
def current_role(self) -> str:
128-
return to_identifier(self.execute_query(f"select current_role()").fetchone()[0])
127+
def current_role(self) -> Optional[str]:
128+
if result := self.execute_query(f"select current_role()").fetchone()[0]:
129+
return to_identifier(result)
130+
return None
129131

130132
@contextmanager
131133
def use_role(self, new_role: str):
@@ -135,14 +137,17 @@ def use_role(self, new_role: str):
135137
"""
136138
new_role = to_identifier(new_role)
137139
prev_role = self.current_role()
138-
is_different_role = new_role.lower() != prev_role.lower()
140+
if prev_role:
141+
is_different_role = new_role.lower() != prev_role.lower()
142+
else:
143+
is_different_role = True
139144
if is_different_role:
140145
self._log.debug("Assuming different role: %s", new_role)
141146
self.execute_query(f"use role {new_role}")
142147
try:
143148
yield
144149
finally:
145-
if is_different_role:
150+
if is_different_role and prev_role:
146151
self.execute_query(f"use role {prev_role}")
147152

148153
def session_has_warehouse(self) -> bool:

tests/api/test_sql_execution.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ def test_use_role_same_id(mock_execute_query, mock_cursor, new_role, current_rol
6363
assert mock_execute_query.mock_calls == [mock.call("select current_role()")]
6464

6565

66+
@mock.patch(EXECUTE_QUERY)
67+
def test_use_role_no_current_role(mock_execute_query, mock_cursor):
68+
mock_execute_query.return_value = mock_cursor([(None,)], [])
69+
with SqlExecutor().use_role("new_role"):
70+
pass
71+
assert mock_execute_query.mock_calls == [
72+
mock.call("select current_role()"),
73+
mock.call(f"use role new_role"),
74+
]
75+
76+
6677
@pytest.mark.parametrize(
6778
"new_warehouse, expected_new_warehouse, current_warehouse, expected_current_warehouse",
6879
[
@@ -94,6 +105,20 @@ def test_use_warehouse_different_id(
94105
]
95106

96107

108+
@mock.patch(EXECUTE_QUERY)
109+
def test_use_warehouse_no_current_wh(
110+
mock_execute_query,
111+
mock_cursor,
112+
):
113+
mock_execute_query.return_value = mock_cursor([(None,)], [])
114+
with SqlExecutor().use_warehouse("new_warehouse"):
115+
pass
116+
assert mock_execute_query.mock_calls == [
117+
mock.call("select current_warehouse()"),
118+
mock.call(f"use warehouse new_warehouse"),
119+
]
120+
121+
97122
@pytest.mark.parametrize(
98123
"new_warehouse, current_warehouse",
99124
[

tests/nativeapp/test_sf_sql_facade.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,24 @@ def test_use_role_same_id(mock_execute_query, mock_cursor, new_role, current_rol
841841
assert mock_execute_query.mock_calls == expected
842842

843843

844+
def test_use_role_current_role_empty(mock_execute_query, mock_cursor):
845+
side_effects, expected = mock_execute_helper(
846+
[
847+
(
848+
mock_cursor([(None,)], []),
849+
mock.call("select current_role()"),
850+
),
851+
(None, mock.call('use role "test role"')),
852+
]
853+
)
854+
mock_execute_query.side_effect = side_effects
855+
856+
with sql_facade._use_role_optional("test role"): # noqa: SLF001
857+
pass
858+
859+
assert mock_execute_query.mock_calls == expected
860+
861+
844862
@pytest.mark.parametrize(
845863
"old_db, expected_old_db",
846864
[("old_db", "old_db"), ("old db", '"old db"')],
@@ -888,6 +906,24 @@ def test_use_db_same_id(mock_execute_query, mock_cursor, new_db, current_db):
888906
assert mock_execute_query.mock_calls == expected
889907

890908

909+
def test_use_db_current_db_empty(mock_execute_query, mock_cursor):
910+
side_effects, expected = mock_execute_helper(
911+
[
912+
(
913+
mock_cursor([(None,)], []),
914+
mock.call("select current_database()"),
915+
),
916+
(None, mock.call('use database "new db"')),
917+
]
918+
)
919+
mock_execute_query.side_effect = side_effects
920+
921+
with sql_facade._use_database_optional("new db"): # noqa: SLF001
922+
pass
923+
924+
assert mock_execute_query.mock_calls == expected
925+
926+
891927
@pytest.mark.parametrize(
892928
"old_schema, expected_old_schema",
893929
[("old_schema", "old_schema"), ("old schema", '"old schema"')],
@@ -941,6 +977,24 @@ def test_use_schema_same_id(
941977
assert mock_execute_query.mock_calls == expected
942978

943979

980+
def test_use_schema_current_schema_empty(mock_execute_query, mock_cursor):
981+
side_effects, expected = mock_execute_helper(
982+
[
983+
(
984+
mock_cursor([(None,)], []),
985+
mock.call("select current_schema()"),
986+
),
987+
(None, mock.call('use schema "new schema"')),
988+
]
989+
)
990+
mock_execute_query.side_effect = side_effects
991+
992+
with sql_facade._use_schema_optional("new schema"): # noqa: SLF001
993+
pass
994+
995+
assert mock_execute_query.mock_calls == expected
996+
997+
944998
@pytest.mark.parametrize(
945999
"error_raised, error_caught, error_message",
9461000
[

0 commit comments

Comments
 (0)