Skip to content

Commit 962ca91

Browse files
committed
refactor(rest): extract SigV4 signing into catalog/rest/sigv4.py
1 parent 5da8186 commit 962ca91

2 files changed

Lines changed: 118 additions & 71 deletions

File tree

pyiceberg/catalog/rest/__init__.py

Lines changed: 24 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from typing_extensions import override
3131

3232
from pyiceberg import __version__
33-
from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary
33+
from pyiceberg.catalog import TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary
3434
from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
3535
from pyiceberg.catalog.rest.response import _handle_non_200_response
3636
from pyiceberg.catalog.rest.scan_planning import (
@@ -44,6 +44,27 @@
4444
ScanTasks,
4545
StorageCredential,
4646
)
47+
from pyiceberg.catalog.rest.sigv4 import (
48+
EMPTY_BODY_SHA256 as EMPTY_BODY_SHA256,
49+
)
50+
from pyiceberg.catalog.rest.sigv4 import (
51+
SIGV4_MAX_RETRIES as SIGV4_MAX_RETRIES,
52+
)
53+
from pyiceberg.catalog.rest.sigv4 import (
54+
SIGV4_MAX_RETRIES_DEFAULT as SIGV4_MAX_RETRIES_DEFAULT,
55+
)
56+
from pyiceberg.catalog.rest.sigv4 import (
57+
SIGV4_REGION as SIGV4_REGION,
58+
)
59+
from pyiceberg.catalog.rest.sigv4 import (
60+
SIGV4_SERVICE as SIGV4_SERVICE,
61+
)
62+
from pyiceberg.catalog.rest.sigv4 import (
63+
SigV4Adapter as SigV4Adapter,
64+
)
65+
from pyiceberg.catalog.rest.sigv4 import (
66+
init_sigv4,
67+
)
4768
from pyiceberg.exceptions import (
4869
AuthorizationExpiredError,
4970
CommitFailedException,
@@ -60,11 +81,6 @@
6081
ViewAlreadyExistsError,
6182
)
6283
from pyiceberg.io import (
63-
AWS_ACCESS_KEY_ID,
64-
AWS_PROFILE_NAME,
65-
AWS_REGION,
66-
AWS_SECRET_ACCESS_KEY,
67-
AWS_SESSION_TOKEN,
6884
FileIO,
6985
load_file_io,
7086
)
@@ -89,7 +105,7 @@
89105
from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel, Identifier, Properties
90106
from pyiceberg.types import transform_dict_value_to_str
91107
from pyiceberg.utils.deprecated import deprecation_message
92-
from pyiceberg.utils.properties import get_first_property_value, get_header_properties, property_as_bool, property_as_int
108+
from pyiceberg.utils.properties import get_header_properties, property_as_bool, property_as_int
93109
from pyiceberg.view import View
94110
from pyiceberg.view.metadata import ViewMetadata, ViewVersion
95111

@@ -251,11 +267,6 @@ class ScanPlanningMode(Enum):
251267
CA_BUNDLE = "cabundle"
252268
SSL = "ssl"
253269
SIGV4 = "rest.sigv4-enabled"
254-
SIGV4_REGION = "rest.signing-region"
255-
SIGV4_SERVICE = "rest.signing-name"
256-
SIGV4_MAX_RETRIES = "rest.sigv4.max-retries"
257-
SIGV4_MAX_RETRIES_DEFAULT = 10
258-
EMPTY_BODY_SHA256: str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
259270
OAUTH2_SERVER_URI = "oauth2-server-uri"
260271
SNAPSHOT_LOADING_MODE = "snapshot-loading-mode"
261272
AUTH = "auth"
@@ -456,7 +467,7 @@ def _create_session(self) -> Session:
456467

457468
# Configure SigV4 Request Signing
458469
if property_as_bool(self.properties, SIGV4, False):
459-
self._init_sigv4(session)
470+
init_sigv4(session, self.uri, self.properties)
460471

461472
return session
462473

@@ -761,64 +772,6 @@ def _split_identifier_for_json(self, identifier: str | Identifier) -> dict[str,
761772
identifier_tuple = self._identifier_to_validated_tuple(identifier)
762773
return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]}
763774

764-
def _init_sigv4(self, session: Session) -> None:
765-
from urllib import parse
766-
767-
import boto3
768-
from botocore.auth import SigV4Auth
769-
from botocore.awsrequest import AWSRequest
770-
from requests import PreparedRequest
771-
from requests.adapters import HTTPAdapter
772-
773-
class SigV4Adapter(HTTPAdapter):
774-
def __init__(self, **properties: str):
775-
self._properties = properties
776-
max_retries = property_as_int(self._properties, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT)
777-
super().__init__(max_retries=max_retries)
778-
self._boto_session = boto3.Session(
779-
profile_name=get_first_property_value(self._properties, AWS_PROFILE_NAME),
780-
region_name=get_first_property_value(self._properties, AWS_REGION),
781-
botocore_session=self._properties.get(BOTOCORE_SESSION),
782-
aws_access_key_id=get_first_property_value(self._properties, AWS_ACCESS_KEY_ID),
783-
aws_secret_access_key=get_first_property_value(self._properties, AWS_SECRET_ACCESS_KEY),
784-
aws_session_token=get_first_property_value(self._properties, AWS_SESSION_TOKEN),
785-
)
786-
787-
def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None: # pylint: disable=W0613
788-
credentials = self._boto_session.get_credentials().get_frozen_credentials()
789-
region = self._properties.get(SIGV4_REGION, self._boto_session.region_name)
790-
service = self._properties.get(SIGV4_SERVICE, "execute-api")
791-
792-
url = str(request.url).split("?")[0]
793-
query = str(parse.urlsplit(request.url).query)
794-
params = dict(parse.parse_qsl(query))
795-
796-
# remove the connection header as it will be updated after signing
797-
if "connection" in request.headers:
798-
del request.headers["connection"]
799-
# For empty bodies, explicitly set the content hash header to the SHA256 of an empty string
800-
if not request.body:
801-
request.headers["x-amz-content-sha256"] = EMPTY_BODY_SHA256
802-
803-
aws_request = AWSRequest(
804-
method=request.method, url=url, params=params, data=request.body, headers=dict(request.headers)
805-
)
806-
807-
SigV4Auth(credentials, service, region).add_auth(aws_request)
808-
original_header = request.headers
809-
signed_headers = aws_request.headers
810-
relocated_headers = {}
811-
812-
# relocate headers if there is a conflict with signed headers
813-
for header, value in original_header.items():
814-
if header in signed_headers and signed_headers[header] != value:
815-
relocated_headers[f"Original-{header}"] = value
816-
817-
request.headers.update(relocated_headers)
818-
request.headers.update(signed_headers)
819-
820-
session.mount(self.uri, SigV4Adapter(**self.properties))
821-
822775
def _response_to_table(self, identifier_tuple: tuple[str, ...], table_response: TableResponse) -> Table:
823776
# Per Iceberg spec: storage-credentials take precedence over config
824777
credential_config = self._resolve_storage_credentials(

pyiceberg/catalog/rest/sigv4.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from typing import Any
20+
from urllib import parse
21+
22+
import boto3
23+
from botocore.auth import SigV4Auth
24+
from botocore.awsrequest import AWSRequest
25+
from requests import PreparedRequest, Session
26+
from requests.adapters import HTTPAdapter
27+
28+
from pyiceberg.catalog import BOTOCORE_SESSION
29+
from pyiceberg.io import (
30+
AWS_ACCESS_KEY_ID,
31+
AWS_PROFILE_NAME,
32+
AWS_REGION,
33+
AWS_SECRET_ACCESS_KEY,
34+
AWS_SESSION_TOKEN,
35+
)
36+
from pyiceberg.utils.properties import get_first_property_value, property_as_int
37+
38+
SIGV4_REGION = "rest.signing-region"
39+
SIGV4_SERVICE = "rest.signing-name"
40+
SIGV4_MAX_RETRIES = "rest.sigv4.max-retries"
41+
SIGV4_MAX_RETRIES_DEFAULT = 10
42+
EMPTY_BODY_SHA256: str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
43+
44+
45+
class SigV4Adapter(HTTPAdapter):
    """HTTP adapter that signs outgoing requests with AWS Signature Version 4.

    A ``boto3.Session`` is created once from the supplied properties and is then
    used to look up credentials every time a request is signed.

    Args:
        **properties: Catalog properties. ``AWS_PROFILE_NAME``, ``AWS_REGION``,
            ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, ``AWS_SESSION_TOKEN``
            and ``BOTOCORE_SESSION`` configure the boto3 session;
            ``rest.signing-region`` and ``rest.signing-name`` override the signing
            region and service; ``rest.sigv4.max-retries`` sets the retry count
            handed to ``HTTPAdapter`` (default ``SIGV4_MAX_RETRIES_DEFAULT``).
    """

    def __init__(self, **properties: str):
        self._properties = properties
        max_retries = property_as_int(self._properties, SIGV4_MAX_RETRIES, SIGV4_MAX_RETRIES_DEFAULT)
        super().__init__(max_retries=max_retries)
        self._boto_session = boto3.Session(
            profile_name=get_first_property_value(self._properties, AWS_PROFILE_NAME),
            region_name=get_first_property_value(self._properties, AWS_REGION),
            botocore_session=self._properties.get(BOTOCORE_SESSION),
            aws_access_key_id=get_first_property_value(self._properties, AWS_ACCESS_KEY_ID),
            aws_secret_access_key=get_first_property_value(self._properties, AWS_SECRET_ACCESS_KEY),
            aws_session_token=get_first_property_value(self._properties, AWS_SESSION_TOKEN),
        )

    def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None:  # pylint: disable=W0613
        """Sign ``request`` in place by adding the SigV4 headers to it.

        Headers already on the request whose value differs from the signed value
        are preserved under an ``Original-<name>`` header before being overwritten.

        Args:
            request: The prepared request about to be sent; its headers are mutated.
            **kwargs: Accepted for compatibility with ``HTTPAdapter.add_headers``; unused.

        Raises:
            NoCredentialsError: If the boto3 session cannot resolve any credentials.
        """
        from botocore.exceptions import NoCredentialsError

        # get_credentials() returns None when nothing can be resolved; fail with a
        # clear error instead of an AttributeError on the chained call.
        session_credentials = self._boto_session.get_credentials()
        if session_credentials is None:
            raise NoCredentialsError()
        credentials = session_credentials.get_frozen_credentials()

        region = self._properties.get(SIGV4_REGION, self._boto_session.region_name)
        service = self._properties.get(SIGV4_SERVICE, "execute-api")

        url = str(request.url).split("?")[0]
        query = str(parse.urlsplit(request.url).query)
        # Keep blank-valued parameters: they are still part of the URL that is sent,
        # so they must also be part of the canonical query string that gets signed.
        params = dict(parse.parse_qsl(query, keep_blank_values=True))

        # remove the connection header as it will be updated after signing
        if "connection" in request.headers:
            del request.headers["connection"]
        # For empty bodies, explicitly set the content hash header to the SHA256 of an empty string
        if not request.body:
            request.headers["x-amz-content-sha256"] = EMPTY_BODY_SHA256

        aws_request = AWSRequest(
            method=request.method, url=url, params=params, data=request.body, headers=dict(request.headers)
        )

        SigV4Auth(credentials, service, region).add_auth(aws_request)
        signed_headers = aws_request.headers

        # relocate headers if there is a conflict with signed headers
        relocated_headers = {
            f"Original-{header}": value
            for header, value in request.headers.items()
            if header in signed_headers and signed_headers[header] != value
        }

        request.headers.update(relocated_headers)
        request.headers.update(signed_headers)
91+
92+
93+
def init_sigv4(session: Session, uri: str, properties: dict[str, str]) -> None:
    """Mount a SigV4-signing adapter on ``session`` under ``uri``.

    Args:
        session: The requests session to attach the adapter to.
        uri: The prefix passed to ``session.mount``.
        properties: Catalog properties forwarded as keyword arguments to ``SigV4Adapter``.
    """
    signing_adapter = SigV4Adapter(**properties)
    session.mount(uri, signing_adapter)

0 commit comments

Comments
 (0)