Skip to content

Commit 06adf55

Browse files
author
dakodakov
authored
vdk-trino: add oauth connection (#3417)
Add the ability to authenticate to Trino using OAuth. Signed-off-by: Dako Dakov <[email protected]>
1 parent 616c471 commit 06adf55

File tree

4 files changed

+181
-66
lines changed

4 files changed

+181
-66
lines changed

projects/vdk-plugins/vdk-trino/src/vdk/plugin/trino/trino_config.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from typing import cast
44
from typing import Optional
55

6+
from vdk.internal.builtin_plugins.config.vdk_config import TEAM_CLIENT_ID
7+
from vdk.internal.builtin_plugins.config.vdk_config import TEAM_CLIENT_SECRET
8+
from vdk.internal.builtin_plugins.config.vdk_config import TEAM_OAUTH_AUTHORIZE_URL
69
from vdk.internal.core.config import Configuration
710

811
TRINO_HOST = "TRINO_HOST"
@@ -15,6 +18,9 @@
1518
TRINO_SSL_VERIFY = "TRINO_SSL_VERIFY"
1619
TRINO_TIMEOUT_SECONDS = "TRINO_TIMEOUT_SECONDS"
1720
TRINO_TEMPLATES_DATA_TO_TARGET_STRATEGY = "TRINO_TEMPLATES_DATA_TO_TARGET_STRATEGY"
21+
TRINO_USE_TEAM_OAUTH = "TRINO_USE_TEAM_OAUTH"
22+
TRINO_RETRIES_ON_ERROR = "TRINO_RETRIES_ON_ERROR"
23+
TRINO_RETRIES_BACKOFF_SECONDS = "TRINO_RETRIES_BACKOFF_SECONDS"
1824

1925
trino_templates_data_to_target_strategy: str = ""
2026

@@ -121,6 +127,50 @@ def templates_data_to_target_strategy(self, section: Optional[str]) -> str:
121127
else "INSERT_SELECT",
122128
)
123129

130+
def use_team_oauth(self, section: Optional[str]) -> bool:
    """
    Tell whether the connection should authenticate with the team's oAuth credentials.

    :param section: the configuration section to read TRINO_USE_TEAM_OAUTH from
    :return: False when the option has no value; otherwise whatever
        ``parse_boolean`` makes of the configured value
    """
    raw_flag = self.__config.get_value(key=TRINO_USE_TEAM_OAUTH, section=section)
    if raw_flag is None:
        return False
    return parse_boolean(raw_flag)
141+
142+
def retries(self, section: Optional[str]) -> int:
    """
    Return the value configured under TRINO_RETRIES_ON_ERROR.

    :param section: the configuration section to read the value from
    :return: the configured value; ``cast`` is a type-checker hint only, so the
        value is handed back exactly as the configuration stores it
    """
    configured_retries = self.__config.get_value(
        key=TRINO_RETRIES_ON_ERROR, section=section
    )
    return cast(int, configured_retries)
146+
147+
def backoff_interval_seconds(self, section: Optional[str]) -> int:
    """
    Return the value configured under TRINO_RETRIES_BACKOFF_SECONDS.

    :param section: the configuration section to read the value from
    :return: the configured value; ``cast`` is a type-checker hint only, so the
        value is handed back exactly as the configuration stores it
    """
    configured_backoff = self.__config.get_value(
        key=TRINO_RETRIES_BACKOFF_SECONDS, section=section
    )
    return cast(int, configured_backoff)
152+
153+
def team_client_id(self) -> Optional[str]:
    """
    Return the value configured under TEAM_CLIENT_ID.

    The value is looked up without a section argument.

    :return: the configured value, or None when nothing is configured
    """
    # cast() is a no-op at runtime, so one lookup covers both the "set" and
    # the "unset" (None) case.
    return cast(Optional[str], self.__config.get_value(key=TEAM_CLIENT_ID))
159+
160+
def team_client_secret(self) -> Optional[str]:
    """
    Return the value configured under TEAM_CLIENT_SECRET.

    The value is looked up without a section argument.

    :return: the configured value, or None when nothing is configured
    """
    # cast() is a no-op at runtime, so one lookup covers both the "set" and
    # the "unset" (None) case.
    return cast(Optional[str], self.__config.get_value(key=TEAM_CLIENT_SECRET))
166+
167+
def team_oauth_url(self) -> Optional[str]:
    """
    Return the value configured under TEAM_OAUTH_AUTHORIZE_URL.

    The value is looked up without a section argument.

    :return: the configured value, or None when nothing is configured
    """
    # cast() is a no-op at runtime, so one lookup covers both the "set" and
    # the "unset" (None) case.
    return cast(
        Optional[str], self.__config.get_value(key=TEAM_OAUTH_AUTHORIZE_URL)
    )
173+
124174
@staticmethod
125175
def add_definitions(config_builder):
126176
"""
@@ -172,3 +222,18 @@ def add_definitions(config_builder):
172222
default_value=None,
173223
description="The trino query timeout in seconds.",
174224
)
225+
config_builder.add(
226+
key=TRINO_USE_TEAM_OAUTH,
227+
default_value=False,
228+
description="Should the connection use the team's oAuth credentials to connect to the DBs",
229+
)
230+
config_builder.add(
231+
key=TRINO_RETRIES_ON_ERROR,
232+
default_value=3,
233+
description="The number of times the Trino plugin is going to retry a failing operation",
234+
)
235+
config_builder.add(
236+
key=TRINO_RETRIES_BACKOFF_SECONDS,
237+
default_value=30,
238+
description="The backoff time in seconds between retries of a failed operation",
239+
)

projects/vdk-plugins/vdk-trino/src/vdk/plugin/trino/trino_connection.py

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
11
# Copyright 2023-2024 Broadcom
22
# SPDX-License-Identifier: Apache-2.0
3+
import base64
4+
import json
35
import logging
6+
from typing import Optional
47

8+
import requests
59
from tenacity import before_sleep_log
610
from tenacity import retry
711
from tenacity import retry_if_exception_type
812
from tenacity import stop_after_attempt
913
from tenacity import wait_exponential
14+
from trino import constants
15+
from trino import dbapi
16+
from trino.auth import BasicAuthentication
1017
from trino.exceptions import TrinoExternalError
1118
from trino.exceptions import TrinoInternalError
1219
from vdk.internal.builtin_plugins.connection.managed_connection_base import (
1320
ManagedConnectionBase,
1421
)
1522
from vdk.internal.builtin_plugins.connection.recovery_cursor import RecoveryCursor
1623
from vdk.internal.util.decorators import closing_noexcept_on_close
24+
from vdk.plugin.trino.trino_config import TrinoConfiguration
1725
from vdk.plugin.trino.trino_error_handler import TrinoErrorHandler
1826

1927
log = logging.getLogger(__name__)
@@ -22,18 +30,9 @@
2230
class TrinoConnection(ManagedConnectionBase):
2331
def __init__(
2432
self,
25-
host,
26-
port,
27-
catalog,
28-
schema,
29-
user,
30-
password,
31-
use_ssl=True,
32-
ssl_verify=True,
33-
timeout_seconds=120,
33+
configuration: TrinoConfiguration,
34+
section: Optional[str],
3435
lineage_logger=None,
35-
retries_on_error=3,
36-
error_backoff_seconds=30,
3736
):
3837
"""
3938
Create a new database connection. Connection parameters are:
@@ -42,36 +41,93 @@ def __init__(
4241
- *port*: connection port number (defaults to 8080 if not provided)
4342
- *catalog*: the catalog name (only as keyword argument)
4443
- *schema*: the schema name (only as keyword argument)
45-
- *user*: user name used to authenticate
44+
- *user*: username used to authenticate
4645
"""
4746
super().__init__(logging.getLogger(__name__))
4847

49-
self._host = host
50-
self._port = port
51-
self._catalog = catalog
52-
self._schema = schema
53-
self._user = user
54-
self._password = password
55-
self._use_ssl = use_ssl
56-
self._ssl_verify = ssl_verify
57-
self._timeout_seconds = timeout_seconds
48+
self._host = configuration.host(section)
49+
self._port = configuration.port(section)
50+
self._catalog = configuration.catalog(section)
51+
self._schema = configuration.schema(section)
52+
self._user = configuration.user(section)
53+
self._password = configuration.password(section)
54+
self._use_ssl = configuration.use_ssl(section)
55+
self._ssl_verify = configuration.ssl_verify(section)
56+
self._timeout_seconds = configuration.timeout_seconds(section)
57+
self._retries_on_error = configuration.retries(section)
58+
self._error_backoff_seconds = configuration.backoff_interval_seconds(section)
59+
5860
self._lineage_logger = lineage_logger
59-
self._retries_on_error = retries_on_error
60-
self._error_backoff_seconds = error_backoff_seconds
61+
62+
self._use_team_oauth = configuration.use_team_oauth(section)
63+
64+
if self._use_team_oauth:
65+
self._team_client_id = configuration.team_client_id()
66+
self._team_client_secret = configuration.team_client_secret()
67+
self._team_oauth_url = configuration.team_oauth_url()
68+
log.debug(
69+
f"Creating new trino connection for oAuth ClientID: {self._team_client_id} to host: {self._host}:{self._port}"
70+
)
71+
else:
72+
log.debug(
73+
f"Creating new trino connection for user: {self._user} to host: {self._host}:{self._port}"
74+
)
75+
76+
def _connect(self):
77+
if self._use_team_oauth:
78+
return self._team_oauth_connection()
79+
else:
80+
return self._basic_authentication_connection()
81+
82+
def _team_oauth_connection(self):
    """
    Open a Trino connection authenticated with the team's oAuth access token.

    An access token is obtained through ``_get_access_token`` and carried by an
    ``OAuthSession`` that is handed to ``dbapi.connect`` as its HTTP session.

    :return: the connection object returned by ``dbapi.connect``
    """
    log.debug(
        f"Open Trino Connection: host: {self._host}:{self._port} with oAuth ClientID: {self._team_client_id}; "
        f"catalog: {self._catalog}; schema: {self._schema}; timeout: {self._timeout_seconds}"
    )

    oauth_token = self._get_access_token()

    # Session that carries the bearer token on every request.
    session = OAuthSession(oauth_token)
    # NOTE(review): the Trino client appears to apply its `verify` argument only
    # to a session it creates itself, so a caller-supplied session would keep
    # the requests default. Set it here so the configured SSL verification
    # setting is honoured on the oAuth path as well - confirm against the
    # installed trino client version.
    session.verify = self._ssl_verify

    return dbapi.connect(
        host=self._host,
        port=self._port,
        catalog=self._catalog,
        schema=self._schema,
        http_scheme=constants.HTTPS if self._use_ssl else constants.HTTP,
        verify=self._ssl_verify,
        request_timeout=self._timeout_seconds,
        http_session=session,
    )
104+
105+
def _get_access_token(self):
    """
    Exchange the team client ID and secret for an access token.

    Sends a ``client_credentials`` grant request to the team oAuth URL, with the
    client ID and secret passed as an HTTP Basic ``Authorization`` header.

    :return: the ``access_token`` field of the JSON response
    :raises ValueError: if the client ID, client secret or oAuth URL is not set
    :raises requests.HTTPError: if the token endpoint answers with an error status
    :raises KeyError: if the response JSON has no ``access_token`` field
    """
    if not (
        self._team_client_id and self._team_client_secret and self._team_oauth_url
    ):
        # Fail with an actionable message instead of a TypeError from the
        # string concatenation below. The secret itself is never logged.
        raise ValueError(
            "Team oAuth is enabled for the Trino connection but the team client ID, "
            "client secret or oAuth authorize URL is not configured."
        )

    # HTTP Basic credentials: base64("<client_id>:<client_secret>")
    credentials = f"{self._team_client_id}:{self._team_client_secret}"
    encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode(
        "utf-8"
    )

    headers = {
        "Authorization": "Basic " + encoded_credentials,
        "Content-Type": "application/x-www-form-urlencoded",
    }
    data = {"grant_type": "client_credentials"}

    # requests waits indefinitely unless a timeout is given. Reuse the
    # connection timeout when configured, else fall back to 120 seconds so a
    # stalled token endpoint cannot hang the job.
    timeout = self._timeout_seconds if self._timeout_seconds is not None else 120
    response = requests.post(
        self._team_oauth_url, headers=headers, data=data, timeout=timeout
    )
    # If this call fails, raise as early as possible.
    response.raise_for_status()

    response_json = json.loads(response.text)
    return response_json["access_token"]
125+
126+
def _basic_authentication_connection(self):
69127
log.debug(
70128
f"Open Trino Connection: host: {self._host}:{self._port} with user: {self._user}; "
71129
f"catalog: {self._catalog}; schema: {self._schema}; timeout: {self._timeout_seconds}"
72130
)
73-
from trino.auth import BasicAuthentication
74-
75131
auth = (
76132
BasicAuthentication(self._user, self._password) if self._password else None
77133
)
@@ -164,3 +220,10 @@ def _get_lineage_data(self, query):
164220
return None
165221

166222
return None
223+
224+
225+
# Define a custom requests session to add the OAuth token to the headers
226+
class OAuthSession(requests.Session):
    """A ``requests.Session`` whose default headers carry an oAuth bearer token."""

    def __init__(self, token):
        """
        :param token: the access token placed in the ``Authorization`` header
        """
        super().__init__()
        bearer_value = f"Bearer {token}"
        self.headers["Authorization"] = bearer_value

projects/vdk-plugins/vdk-trino/src/vdk/plugin/trino/trino_plugin.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -112,29 +112,15 @@ def initialize_job(self, context: JobContext):
112112
port = trino_conf.port(section)
113113

114114
if host and port:
115-
schema = trino_conf.schema(section)
116-
catalog = trino_conf.catalog(section)
117-
user = trino_conf.user(section)
118-
password = trino_conf.password(section)
119-
use_ssl = trino_conf.use_ssl(section)
120-
ssl_verify = trino_conf.ssl_verify(section)
121-
timeout_seconds = trino_conf.timeout_seconds(section)
122115
lineage_logger = context.core_context.state.get(LINEAGE_LOGGER_KEY)
123116
log.info(
124117
f"Creating new Trino connection with name {connection_name} and host {host}"
125118
)
126119
context.connections.add_open_connection_factory_method(
127120
connection_name.lower(),
128-
lambda t_host=host, t_port=port, t_schema=schema, t_catalog=catalog, t_user=user, t_password=password, t_use_ssl=use_ssl, t_ssl_verify=ssl_verify, t_timeout=timeout_seconds, t_lineage_logger=lineage_logger: TrinoConnection(
129-
host=t_host,
130-
port=t_port,
131-
schema=t_schema,
132-
catalog=t_catalog,
133-
user=t_user,
134-
password=t_password,
135-
use_ssl=t_use_ssl,
136-
ssl_verify=t_ssl_verify,
137-
timeout_seconds=t_timeout,
121+
lambda t_configuration=trino_conf, t_section=section, t_lineage_logger=lineage_logger: TrinoConnection(
122+
configuration=t_configuration,
123+
section=t_section,
138124
lineage_logger=t_lineage_logger,
139125
),
140126
)
@@ -157,7 +143,7 @@ def initialize_job(self, context: JobContext):
157143
)
158144
except Exception as e:
159145
raise Exception(
160-
"An error occurred while trying to create new Trino connections and ingesters."
146+
f"An error occurred while trying to create new Trino connections and ingesters for connection:{connection_name}."
161147
f"ERROR: {e}"
162148
)
163149

@@ -205,17 +191,10 @@ def vdk_start(plugin_registry: IPluginRegistry, command_line_args: List):
205191
@click.option("-q", "--query", type=click.STRING, required=True)
206192
@click.pass_context
207193
def trino_query(ctx: click.Context, query):
208-
conf = ctx.obj.configuration
194+
trino_conf = TrinoConfiguration(ctx.obj.configuration)
209195
conn = TrinoConnection(
210-
host=conf.get_value(TRINO_HOST),
211-
port=conf.get_value(TRINO_PORT),
212-
schema=conf.get_value(TRINO_SCHEMA),
213-
catalog=conf.get_value(TRINO_CATALOG),
214-
user=conf.get_value(TRINO_USER),
215-
password=conf.get_value(TRINO_PASSWORD),
216-
use_ssl=conf.get_value(TRINO_USE_SSL),
217-
ssl_verify=conf.get_value(TRINO_SSL_VERIFY),
218-
timeout_seconds=conf.get_value(TRINO_TIMEOUT_SECONDS),
196+
configuration=trino_conf,
197+
section=None,
219198
lineage_logger=ctx.obj.state.get(LINEAGE_LOGGER_KEY),
220199
)
221200

projects/vdk-plugins/vdk-trino/tests/test_trino_multiple_db.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
from unittest import TestCase
88

99
import pytest
10+
from vdk.internal.core.config import ConfigurationBuilder
1011
from vdk.plugin.test_utils.util_funcs import cli_assert_equal
1112
from vdk.plugin.test_utils.util_funcs import CliEntryBasedTestRunner
1213
from vdk.plugin.test_utils.util_funcs import get_test_job_path
1314
from vdk.plugin.trino import trino_plugin
15+
from vdk.plugin.trino.trino_config import TrinoConfiguration
1416
from vdk.plugin.trino.trino_connection import TrinoConnection
1517

1618
VDK_DB_DEFAULT_TYPE = "VDK_DB_DEFAULT_TYPE"
@@ -61,16 +63,22 @@ def test_ingest_to_multiple_trino(self):
6163
)
6264

6365
# check secondary db
66+
builder = ConfigurationBuilder()
67+
builder.add("TRINO_HOST", "localhost")
68+
builder.add("TRINO_PORT", 8081)
69+
builder.add("TRINO_SCHEMA", "default")
70+
builder.add("TRINO_CATALOG", "memory")
71+
builder.add("TRINO_USER", "unknown")
72+
builder.add("TRINO_PASSWORD", None)
73+
builder.add("TRINO_USE_SSL", False)
74+
builder.add("TRINO_SSL_VERIFY", True)
75+
builder.add("TRINO_TIMEOUT_SECONDS", None)
76+
cfg = builder.build()
77+
78+
trino_conf = TrinoConfiguration(cfg)
6479
conn = TrinoConnection(
65-
host="localhost",
66-
port=8081,
67-
schema="default",
68-
catalog="memory", # default
69-
user="unknown", # default
70-
password=None,
71-
use_ssl=False,
72-
ssl_verify=True, # default
73-
timeout_seconds=None,
80+
configuration=trino_conf,
81+
section=None,
7482
lineage_logger=None,
7583
)
7684

0 commit comments

Comments
 (0)