Skip to content

Commit d9d46b0

Browse files
Pablo Takarahashhar
authored andcommitted
Use OAuth2 if externalAuthentication is present in connection url
After this change if 'externalAuthentication' is passed as a parameter on the connection url we automatically set `http_schema` to `http` and use `OAuth2Authentication`.
1 parent 0584c93 commit d9d46b0

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

tests/unit/sqlalchemy/test_dialect.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
from sqlalchemy.engine.url import URL, make_url
66

7-
from trino.auth import BasicAuthentication
7+
from trino.auth import BasicAuthentication, OAuth2Authentication
88
from trino.dbapi import Connection
99
from trino.sqlalchemy import URL as trino_url
1010
from trino.sqlalchemy.dialect import (
@@ -296,3 +296,12 @@ def test_trino_connection_certificate_auth():
296296
assert isinstance(cparams['auth'], CertificateAuthentication)
297297
assert cparams['auth']._cert == cert
298298
assert cparams['auth']._key == key
299+
300+
301+
def test_trino_connection_oauth2_auth():
302+
dialect = TrinoDialect()
303+
url = make_url('trino://host/?externalAuthentication=true')
304+
_, cparams = dialect.create_connect_args(url)
305+
306+
assert cparams['http_scheme'] == "https"
307+
assert isinstance(cparams['auth'], OAuth2Authentication)

trino/sqlalchemy/dialect.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222

2323
from trino import dbapi as trino_dbapi
2424
from trino import logging
25-
from trino.auth import BasicAuthentication, CertificateAuthentication, JWTAuthentication
25+
from trino.auth import (
26+
BasicAuthentication,
27+
CertificateAuthentication,
28+
JWTAuthentication,
29+
OAuth2Authentication,
30+
)
2631
from trino.dbapi import Cursor
2732
from trino.sqlalchemy import compiler, datatype, error
2833

@@ -113,6 +118,10 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
113118
kwargs["http_scheme"] = "https"
114119
kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key']))
115120

121+
if "externalAuthentication" in url.query:
122+
kwargs["http_scheme"] = "https"
123+
kwargs["auth"] = OAuth2Authentication()
124+
116125
if "source" in url.query:
117126
kwargs["source"] = unquote_plus(url.query["source"])
118127
else:

0 commit comments

Comments
 (0)