@@ -841,6 +841,24 @@ def test_use_role_same_id(mock_execute_query, mock_cursor, new_role, current_rol
841
841
assert mock_execute_query .mock_calls == expected
842
842
843
843
844
+ def test_use_role_current_role_empty (mock_execute_query , mock_cursor ):
845
+ side_effects , expected = mock_execute_helper (
846
+ [
847
+ (
848
+ mock_cursor ([(None ,)], []),
849
+ mock .call ("select current_role()" ),
850
+ ),
851
+ (None , mock .call ('use role "test role"' )),
852
+ ]
853
+ )
854
+ mock_execute_query .side_effect = side_effects
855
+
856
+ with sql_facade ._use_role_optional ("test role" ): # noqa: SLF001
857
+ pass
858
+
859
+ assert mock_execute_query .mock_calls == expected
860
+
861
+
844
862
@pytest .mark .parametrize (
845
863
"old_db, expected_old_db" ,
846
864
[("old_db" , "old_db" ), ("old db" , '"old db"' )],
@@ -888,6 +906,24 @@ def test_use_db_same_id(mock_execute_query, mock_cursor, new_db, current_db):
888
906
assert mock_execute_query .mock_calls == expected
889
907
890
908
909
+ def test_use_db_current_db_empty (mock_execute_query , mock_cursor ):
910
+ side_effects , expected = mock_execute_helper (
911
+ [
912
+ (
913
+ mock_cursor ([(None ,)], []),
914
+ mock .call ("select current_database()" ),
915
+ ),
916
+ (None , mock .call ('use database "new db"' )),
917
+ ]
918
+ )
919
+ mock_execute_query .side_effect = side_effects
920
+
921
+ with sql_facade ._use_database_optional ("new db" ): # noqa: SLF001
922
+ pass
923
+
924
+ assert mock_execute_query .mock_calls == expected
925
+
926
+
891
927
@pytest .mark .parametrize (
892
928
"old_schema, expected_old_schema" ,
893
929
[("old_schema" , "old_schema" ), ("old schema" , '"old schema"' )],
@@ -941,6 +977,24 @@ def test_use_schema_same_id(
941
977
assert mock_execute_query .mock_calls == expected
942
978
943
979
980
+ def test_use_schema_current_schema_empty (mock_execute_query , mock_cursor ):
981
+ side_effects , expected = mock_execute_helper (
982
+ [
983
+ (
984
+ mock_cursor ([(None ,)], []),
985
+ mock .call ("select current_schema()" ),
986
+ ),
987
+ (None , mock .call ('use schema "new schema"' )),
988
+ ]
989
+ )
990
+ mock_execute_query .side_effect = side_effects
991
+
992
+ with sql_facade ._use_schema_optional ("new schema" ): # noqa: SLF001
993
+ pass
994
+
995
+ assert mock_execute_query .mock_calls == expected
996
+
997
+
944
998
@pytest .mark .parametrize (
945
999
"error_raised, error_caught, error_message" ,
946
1000
[
0 commit comments