11# Copyright 2023-2024 Broadcom
22# SPDX-License-Identifier: Apache-2.0
3+ import base64
4+ import json
35import logging
6+ from typing import Optional
47
8+ import requests
59from tenacity import before_sleep_log
610from tenacity import retry
711from tenacity import retry_if_exception_type
812from tenacity import stop_after_attempt
913from tenacity import wait_exponential
14+ from trino import constants
15+ from trino import dbapi
16+ from trino .auth import BasicAuthentication
1017from trino .exceptions import TrinoExternalError
1118from trino .exceptions import TrinoInternalError
1219from vdk .internal .builtin_plugins .connection .managed_connection_base import (
1320 ManagedConnectionBase ,
1421)
1522from vdk .internal .builtin_plugins .connection .recovery_cursor import RecoveryCursor
1623from vdk .internal .util .decorators import closing_noexcept_on_close
24+ from vdk .plugin .trino .trino_config import TrinoConfiguration
1725from vdk .plugin .trino .trino_error_handler import TrinoErrorHandler
1826
1927log = logging .getLogger (__name__ )
2230class TrinoConnection (ManagedConnectionBase ):
2331 def __init__ (
2432 self ,
25- host ,
26- port ,
27- catalog ,
28- schema ,
29- user ,
30- password ,
31- use_ssl = True ,
32- ssl_verify = True ,
33- timeout_seconds = 120 ,
33+ configuration : TrinoConfiguration ,
34+ section : Optional [str ],
3435 lineage_logger = None ,
35- retries_on_error = 3 ,
36- error_backoff_seconds = 30 ,
3736 ):
3837 """
3938 Create a new database connection. Connection parameters are:
@@ -42,36 +41,93 @@ def __init__(
4241 - *port*: connection port number (defaults to 8080 if not provided)
4342 - *catalog*: the catalog name (only as keyword argument)
4443 - *schema*: the schema name (only as keyword argument)
45- - *user*: user name used to authenticate
44+ - *user*: username used to authenticate
4645 """
4746 super ().__init__ (logging .getLogger (__name__ ))
4847
49- self ._host = host
50- self ._port = port
51- self ._catalog = catalog
52- self ._schema = schema
53- self ._user = user
54- self ._password = password
55- self ._use_ssl = use_ssl
56- self ._ssl_verify = ssl_verify
57- self ._timeout_seconds = timeout_seconds
48+ self ._host = configuration .host (section )
49+ self ._port = configuration .port (section )
50+ self ._catalog = configuration .catalog (section )
51+ self ._schema = configuration .schema (section )
52+ self ._user = configuration .user (section )
53+ self ._password = configuration .password (section )
54+ self ._use_ssl = configuration .use_ssl (section )
55+ self ._ssl_verify = configuration .ssl_verify (section )
56+ self ._timeout_seconds = configuration .timeout_seconds (section )
57+ self ._retries_on_error = configuration .retries (section )
58+ self ._error_backoff_seconds = configuration .backoff_interval_seconds (section )
59+
5860 self ._lineage_logger = lineage_logger
59- self ._retries_on_error = retries_on_error
60- self ._error_backoff_seconds = error_backoff_seconds
61+
62+ self ._use_team_oauth = configuration .use_team_oauth (section )
63+
64+ if self ._use_team_oauth :
65+ self ._team_client_id = configuration .team_client_id ()
66+ self ._team_client_secret = configuration .team_client_secret ()
67+ self ._team_oauth_url = configuration .team_oauth_url ()
68+ log .debug (
69+ f"Creating new trino connection for oAuth ClientID: { self ._team_client_id } to host: { self ._host } :{ self ._port } "
70+ )
71+ else :
72+ log .debug (
73+ f"Creating new trino connection for user: { self ._user } to host: { self ._host } :{ self ._port } "
74+ )
75+
76+ def _connect (self ):
77+ if self ._use_team_oauth :
78+ return self ._team_oauth_connection ()
79+ else :
80+ return self ._basic_authentication_connection ()
81+
82+ def _team_oauth_connection (self ):
6183 log .debug (
62- f"Creating new trino connection for user: { user } to host: { host } :{ port } "
84+ f"Open Trino Connection: host: { self ._host } :{ self ._port } with oAuth ClientID: { self ._team_client_id } ; "
85+ f"catalog: { self ._catalog } ; schema: { self ._schema } ; timeout: { self ._timeout_seconds } "
6386 )
6487
65- def _connect (self ):
66- from trino import dbapi
67- from trino import constants
88+ oauth_token = self ._get_access_token ()
89+
90+ # Create an OAuth session
91+ session = OAuthSession (oauth_token )
6892
93+ connection = dbapi .connect (
94+ host = self ._host ,
95+ port = self ._port ,
96+ catalog = self ._catalog ,
97+ schema = self ._schema ,
98+ http_scheme = constants .HTTPS if self ._use_ssl else constants .HTTP ,
99+ verify = self ._ssl_verify ,
100+ request_timeout = self ._timeout_seconds ,
101+ http_session = session ,
102+ )
103+ return connection
104+
105+ def _get_access_token (self ):
106+ # Exchange client ID & Secret for an access token
107+ # Original basic auth string
108+ original_string = self ._team_client_id + ":" + self ._team_client_secret
109+ # Encode
110+ encoded_bytes = base64 .b64encode (original_string .encode ("utf-8" ))
111+ encoded_string = encoded_bytes .decode ("utf-8" )
112+
113+ headers = {
114+ "Authorization" : "Basic " + encoded_string ,
115+ "Content-Type" : "application/x-www-form-urlencoded" ,
116+ }
117+ data = {"grant_type" : "client_credentials" }
118+ response = requests .post (self ._team_oauth_url , headers = headers , data = data )
119+ # If this call fails then, we better raise it as early as possible
120+ response .raise_for_status ()
121+
122+ response_json = json .loads (response .text )
123+ oauth_token = response_json ["access_token" ]
124+ return oauth_token
125+
126+ def _basic_authentication_connection (self ):
69127 log .debug (
70128 f"Open Trino Connection: host: { self ._host } :{ self ._port } with user: { self ._user } ; "
71129 f"catalog: { self ._catalog } ; schema: { self ._schema } ; timeout: { self ._timeout_seconds } "
72130 )
73- from trino .auth import BasicAuthentication
74-
75131 auth = (
76132 BasicAuthentication (self ._user , self ._password ) if self ._password else None
77133 )
@@ -164,3 +220,10 @@ def _get_lineage_data(self, query):
164220 return None
165221
166222 return None
223+
224+
225+ # Define a custom requests session to add the OAuth token to the headers
226+ class OAuthSession (requests .Session ):
227+ def __init__ (self , token ):
228+ super ().__init__ ()
229+ self .headers .update ({"Authorization" : f"Bearer { token } " })
0 commit comments