Skip to content

Commit e751241

Browse files
gkarghashhar
authored andcommitted
Add GSSAPIAuthentication authentication class.
1 parent b8ce27d commit e751241

File tree

5 files changed

+152
-8
lines changed

5 files changed

+152
-8
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676
sudo apt-get update
7777
sudo apt-get install libkrb5-dev
7878
pip install wheel
79-
pip install .[tests] sqlalchemy${{ matrix.sqlalchemy }}
79+
pip install .[tests,gssapi] sqlalchemy${{ matrix.sqlalchemy }}
8080
- name: Run tests
8181
run: |
8282
pytest -s tests/

README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,43 @@ the [`Kerberos` authentication type](https://trino.io/docs/current/security/kerb
326326
)
327327
```
328328

329+
### GSSAPI authentication
330+
331+
The `GSSAPIAuthentication` class can be used to connect to a Trino cluster configured with
332+
the [`Kerberos` authentication type](https://trino.io/docs/current/security/kerberos.html):
333+
334+
It follows the interface for `KerberosAuthentication`, but is using
335+
[requests-gssapi](https://github.com/pythongssapi/requests-gssapi), instead of [requests-kerberos](https://github.com/requests/requests-kerberos) under the hood.
336+
337+
- DBAPI
338+
339+
```python
340+
from trino.dbapi import connect
341+
from trino.auth import GSSAPIAuthentication
342+
343+
conn = connect(
344+
user="<username>",
345+
auth=GSSAPIAuthentication(...),
346+
http_scheme="https",
347+
...
348+
)
349+
```
350+
351+
- SQLAlchemy
352+
353+
```python
354+
from sqlalchemy import create_engine
355+
from trino.auth import GSSAPIAuthentication
356+
357+
engine = create_engine(
358+
"trino://<username>@<host>:<port>/<catalog>",
359+
connect_args={
360+
"auth": GSSAPIAuthentication(...),
361+
"http_scheme": "https",
362+
}
363+
)
364+
```
365+
329366
## User impersonation
330367

331368
In the case where user who submits the query is not the same as user who authenticates to Trino server (e.g in Superset),

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
readme = f.read()
2828

2929
kerberos_require = ["requests_kerberos"]
30+
gssapi_require = ["requests_kerberos"]
3031
sqlalchemy_require = ["sqlalchemy >= 1.3"]
3132
external_authentication_token_cache_require = ["keyring"]
3233

@@ -86,6 +87,7 @@
8687
extras_require={
8788
"all": all_require,
8889
"kerberos": kerberos_require,
90+
"gssapi": gssapi_require,
8991
"sqlalchemy": sqlalchemy_require,
9092
"tests": tests_require,
9193
"external-authentication-token-cache": external_authentication_token_cache_require,

tests/unit/test_client.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pytest
2323
import requests
2424
from httpretty import httprettified
25+
from requests_gssapi.exceptions import SPNEGOExchangeError
2526
from requests_kerberos.exceptions import KerberosExchangeError
2627
from tzlocal import get_localzone_name # type: ignore
2728

@@ -39,7 +40,7 @@
3940
_post_statement_requests,
4041
)
4142
from trino import __version__, constants
42-
from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
43+
from trino.auth import GSSAPIAuthentication, KerberosAuthentication, _OAuth2TokenBearer
4344
from trino.client import (
4445
ClientSession,
4546
TrinoQuery,
@@ -883,15 +884,22 @@ def retry_count(self):
883884
return self._retry_count
884885

885886

886-
def test_authentication_fail_retry(monkeypatch):
887-
post_retry = RetryRecorder(error=KerberosExchangeError())
887+
@pytest.mark.parametrize(
888+
"auth_method, retry_exception",
889+
[
890+
(KerberosAuthentication, KerberosExchangeError),
891+
(GSSAPIAuthentication, SPNEGOExchangeError),
892+
]
893+
)
894+
def test_authentication_fail_retry(auth_class, retry_exception_class, monkeypatch):
895+
post_retry = RetryRecorder(error=retry_exception_class())
888896
monkeypatch.setattr(TrinoRequest.http.Session, "post", post_retry)
889897

890-
get_retry = RetryRecorder(error=KerberosExchangeError())
898+
get_retry = RetryRecorder(error=retry_exception_class())
891899
monkeypatch.setattr(TrinoRequest.http.Session, "get", get_retry)
892900

893901
attempts = 3
894-
kerberos_auth = KerberosAuthentication()
902+
kerberos_auth = auth_class()
895903
req = TrinoRequest(
896904
host="coordinator",
897905
port=8080,
@@ -903,11 +911,11 @@ def test_authentication_fail_retry(monkeypatch):
903911
max_attempts=attempts,
904912
)
905913

906-
with pytest.raises(KerberosExchangeError):
914+
with pytest.raises(retry_exception_class):
907915
req.post("URL")
908916
assert post_retry.retry_count == attempts
909917

910-
with pytest.raises(KerberosExchangeError):
918+
with pytest.raises(retry_exception_class):
911919
req.get("URL")
912920
assert post_retry.retry_count == attempts
913921

trino/auth.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,103 @@ def __eq__(self, other: object) -> bool:
107107
and self._ca_bundle == other._ca_bundle)
108108

109109

110+
class GSSAPIAuthentication(Authentication):
111+
def __init__(
112+
self,
113+
config: Optional[str] = None,
114+
service_name: Optional[str] = None,
115+
mutual_authentication: bool = False,
116+
force_preemptive: bool = False,
117+
hostname_override: Optional[str] = None,
118+
sanitize_mutual_error_response: bool = True,
119+
principal: Optional[str] = None,
120+
delegate: bool = False,
121+
ca_bundle: Optional[str] = None,
122+
) -> None:
123+
self._config = config
124+
self._service_name = service_name
125+
self._mutual_authentication = mutual_authentication
126+
self._force_preemptive = force_preemptive
127+
self._hostname_override = hostname_override
128+
self._sanitize_mutual_error_response = sanitize_mutual_error_response
129+
self._principal = principal
130+
self._delegate = delegate
131+
self._ca_bundle = ca_bundle
132+
133+
def set_http_session(self, http_session: Session) -> Session:
134+
try:
135+
import requests_gssapi
136+
except ImportError:
137+
raise RuntimeError("unable to import requests_gssapi")
138+
139+
if self._config:
140+
os.environ["KRB5_CONFIG"] = self._config
141+
http_session.trust_env = False
142+
http_session.auth = requests_gssapi.HTTPSPNEGOAuth(
143+
mutual_authentication=self._mutual_authentication,
144+
opportunistic_auth=self._force_preemptive,
145+
target_name=self._get_target_name(self._hostname_override, self._service_name),
146+
sanitize_mutual_error_response=self._sanitize_mutual_error_response,
147+
creds=self._get_credentials(self._principal),
148+
delegate=self._delegate,
149+
)
150+
if self._ca_bundle:
151+
http_session.verify = self._ca_bundle
152+
return http_session
153+
154+
def _get_credentials(self, principal: Optional[str] = None) -> Any:
155+
if principal:
156+
try:
157+
import gssapi
158+
except ImportError:
159+
raise RuntimeError("unable to import gssapi")
160+
161+
name = gssapi.Name(principal, gssapi.NameType.user)
162+
return gssapi.Credentials(name=name, usage="initiate")
163+
164+
return None
165+
166+
def _get_target_name(
167+
self,
168+
hostname_override: Optional[str] = None,
169+
service_name: Optional[str] = None,
170+
) -> Any:
171+
if service_name is not None:
172+
try:
173+
import gssapi
174+
except ImportError:
175+
raise RuntimeError("unable to import gssapi")
176+
177+
if hostname_override is None:
178+
raise ValueError("service name must be used together with hostname_override")
179+
180+
kerb_spn = "{0}@{1}".format(service_name, hostname_override)
181+
return gssapi.Name(kerb_spn, gssapi.NameType.hostbased_service)
182+
183+
return hostname_override
184+
185+
def get_exceptions(self) -> Tuple[Any, ...]:
186+
try:
187+
from requests_gssapi.exceptions import SPNEGOExchangeError
188+
189+
return SPNEGOExchangeError,
190+
except ImportError:
191+
raise RuntimeError("unable to import requests_kerberos")
192+
193+
def __eq__(self, other: object) -> bool:
194+
if not isinstance(other, GSSAPIAuthentication):
195+
return False
196+
return (self._config == other._config
197+
and self._service_name == other._service_name
198+
and self._mutual_authentication == other._mutual_authentication
199+
and self._force_preemptive == other._force_preemptive
200+
and self._hostname_override == other._hostname_override
201+
and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response
202+
and self._principal == other._principal
203+
and self._delegate == other._delegate
204+
and self._ca_bundle == other._ca_bundle)
205+
206+
110207
class BasicAuthentication(Authentication):
111208
def __init__(self, username: str, password: str):
112209
self._username = username

0 commit comments

Comments
 (0)