Skip to content

Commit 6ea25a3

Browse files
izeigermanhashhar
authored andcommitted
Fix Update user-related HTTP headers
1 parent 606a033 commit 6ea25a3

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed

tests/unit/test_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def assert_headers(headers):
130130
assert headers[constants.HEADER_CATALOG] == catalog
131131
assert headers[constants.HEADER_SCHEMA] == schema
132132
assert headers[constants.HEADER_SOURCE] == source
133-
assert headers[constants.HEADER_USER] == user
134-
assert headers[constants.HEADER_AUTHORIZATION_USER] == authorization_user
133+
assert headers[constants.HEADER_ORIGINAL_USER] == user
134+
assert headers[constants.HEADER_USER] == authorization_user
135135
assert headers[constants.HEADER_SESSION] == ""
136136
assert headers[constants.HEADER_TRANSACTION] is None
137137
assert headers[constants.HEADER_TIMEZONE] == timezone

trino/auth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import trino.logging
3636
from trino import exceptions
37+
from trino.constants import HEADER_ORIGINAL_USER
3738
from trino.constants import HEADER_USER
3839
from trino.constants import MAX_NT_PASSWORD_SIZE
3940

@@ -552,7 +553,7 @@ def _determine_host(url: Optional[str]) -> Any:
552553

553554
@staticmethod
554555
def _determine_user(headers: Mapping[Any, Any]) -> Optional[Any]:
555-
return headers.get(HEADER_USER)
556+
return headers.get(HEADER_ORIGINAL_USER, headers.get(HEADER_USER))
556557

557558
@staticmethod
558559
def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[str]:

trino/client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,8 +511,11 @@ def http_headers(self) -> CaseInsensitiveDict[str]:
511511
headers[constants.HEADER_CATALOG] = self._client_session.catalog
512512
headers[constants.HEADER_SCHEMA] = self._client_session.schema
513513
headers[constants.HEADER_SOURCE] = self._client_session.source
514-
headers[constants.HEADER_USER] = self._client_session.user
515-
headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user
514+
if self._client_session.authorization_user is not None:
515+
headers[constants.HEADER_ORIGINAL_USER] = self._client_session.user
516+
headers[constants.HEADER_USER] = self._client_session.authorization_user
517+
else:
518+
headers[constants.HEADER_USER] = self._client_session.user
516519
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
517520
if self._client_session.encoding is None:
518521
pass

trino/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
HEADER_SCHEMA = "X-Trino-Schema"
3434
HEADER_SOURCE = "X-Trino-Source"
3535
HEADER_USER = "X-Trino-User"
36+
HEADER_ORIGINAL_USER = "X-Trino-Original-User"
3637
HEADER_CLIENT_INFO = "X-Trino-Client-Info"
3738
HEADER_CLIENT_TAGS = "X-Trino-Client-Tags"
3839
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
@@ -61,7 +62,6 @@
6162
CLIENT_CAPABILITY_SESSION_AUTHORIZATION = "SESSION_AUTHORIZATION"
6263
CLIENT_CAPABILITIES = ','.join([CLIENT_CAPABILITY_PARAMETRIC_DATETIME, CLIENT_CAPABILITY_SESSION_AUTHORIZATION])
6364

64-
HEADER_AUTHORIZATION_USER = "X-Trino-Authorization-User"
6565
HEADER_SET_AUTHORIZATION_USER = "X-Trino-Set-Authorization-User"
6666
HEADER_RESET_AUTHORIZATION_USER = "X-Trino-Reset-Authorization-User"
6767

0 commit comments

Comments
 (0)