Skip to content

Commit 5c83292

Browse files
sfc-gh-fpawlowskisfc-gh-turbaszek
authored andcommitted
[async] Applied #2509 to async code
1 parent 3dd7e40 commit 5c83292

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ async def __open_connection(self):
374374
),
375375
scope=self._oauth_scope,
376376
connection=self,
377+
credentials_in_body=self._oauth_credentials_in_body,
377378
)
378379
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:
379380
self.auth_class = AuthByPAT(self._token)

src/snowflake/connector/aio/auth/_oauth_credentials.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
token_request_url: str,
2828
scope: str,
2929
connection: SnowflakeConnection | None = None,
30+
credentials_in_body: bool = False,
3031
**kwargs,
3132
) -> None:
3233
"""Initializes an instance with OAuth client credentials parameters."""
@@ -41,6 +42,7 @@ def __init__(
4142
token_request_url=token_request_url,
4243
scope=scope,
4344
connection=connection,
45+
credentials_in_body=credentials_in_body,
4446
**kwargs,
4547
)
4648

test/unit/aio/test_auth_oauth_credentials_async.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@ async def test_auth_oauth_credentials_oauth_type():
2828

2929

3030
@pytest.mark.parametrize(
31-
"authenticator", ["OAUTH_CLIENT_CREDENTIALS", "oauth_client_credentials"]
31+
"authenticator, oauth_credentials_in_body",
32+
[
33+
("OAUTH_CLIENT_CREDENTIALS", True),
34+
("oauth_client_credentials", False),
35+
("Oauth_Client_Credentials", None),
36+
],
3237
)
3338
async def test_oauth_client_credentials_authenticator_is_case_insensitive(
34-
monkeypatch, authenticator
39+
monkeypatch, authenticator, oauth_credentials_in_body
3540
):
3641
"""Test that OAuth client credentials authenticator is case insensitive."""
3742
import snowflake.connector.aio
@@ -55,35 +60,62 @@ async def mock_post_request(self, url, headers, json_body, **kwargs):
5560
)
5661

5762
# Mock the OAuth client credentials token request to avoid making HTTP requests
58-
# Note: We need to mock _request_tokens which is called by the sync prepare() method
59-
def mock_request_tokens(self, **kwargs):
60-
# Simulate successful token retrieval
61-
# Return a tuple directly (not a coroutine) since it's called from sync code
63+
def mock_get_request_token_response(self, connection, fields):
64+
# Return fields to verify they are set correctly in tests
6265
return (
63-
"mock_access_token",
64-
None, # Client credentials doesn't use refresh tokens
66+
str(fields),
67+
None,
6568
)
6669

6770
monkeypatch.setattr(
6871
AuthByOauthCredentials,
69-
"_request_tokens",
70-
mock_request_tokens,
72+
"_get_request_token_response",
73+
mock_get_request_token_response,
7174
)
7275

76+
oauth_credentials_in_body_arg = (
77+
{"oauth_credentials_in_body": oauth_credentials_in_body}
78+
if oauth_credentials_in_body is not None
79+
else {}
80+
)
7381
# Create connection with OAuth client credentials authenticator
7482
conn = snowflake.connector.aio.SnowflakeConnection(
7583
user="testuser",
7684
account="testaccount",
7785
authenticator=authenticator,
7886
oauth_client_id="test_client_id",
7987
oauth_client_secret="test_client_secret",
88+
**oauth_credentials_in_body_arg,
8089
)
8190

8291
await conn.connect()
8392

8493
# Verify that the auth_class is an instance of AuthByOauthCredentials
8594
assert isinstance(conn.auth_class, AuthByOauthCredentials)
8695

96+
# Verify that the credentials_in_body attribute is set correctly
97+
expected_credentials_in_body = (
98+
oauth_credentials_in_body if oauth_credentials_in_body is not None else False
99+
)
100+
assert conn.auth_class._credentials_in_body is expected_credentials_in_body
101+
102+
str_fields, _ = conn.auth_class._request_tokens(
103+
conn=conn,
104+
authenticator=authenticator,
105+
account="<unused-acount>",
106+
user="<unused-user>",
107+
service_name=None,
108+
)
109+
credential_fields = (
110+
", 'client_id': 'test_client_id', 'client_secret': 'test_client_secret'"
111+
if expected_credentials_in_body
112+
else ""
113+
)
114+
assert (
115+
str_fields
116+
== "{'grant_type': 'client_credentials', 'scope': ''" + credential_fields + "}"
117+
)
118+
87119
await conn.close()
88120

89121

0 commit comments

Comments
 (0)