generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathrequests_signer.py
More file actions
77 lines (64 loc) · 2.11 KB
/
requests_signer.py
File metadata and controls
77 lines (64 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
Sample signer using Requests.
"""
import typing
from urllib.parse import urlparse
from aws_sdk_signers import AWSRequest, Field, Fields, SigV4Signer, URI
from requests import PreparedRequest
from requests.auth import AuthBase
if typing.TYPE_CHECKING:
from aws_sdk_signers import AWSCredentialIdentity, SigV4SigningProperties
# Header names that SigV4Auth.sign_request copies from the signed AWSRequest
# back onto the outgoing requests.PreparedRequest. A name is copied only if it
# is present on the signed request; everything else on the prepared request is
# left untouched.
SIGNING_HEADERS = (
    "Authorization", "Date",
    "X-Amz-Date", "X-Amz-Security-Token", "X-Amz-Content-SHA256",
)
class SigV4Auth(AuthBase):
    """Attaches SigV4 authentication to the given Request object.

    Pass an instance as ``auth=`` to ``requests``. For each request, the
    prepared request is converted to an ``AWSRequest`` and signed with
    ``SigV4Signer``. The headers named in ``SIGNING_HEADERS`` are then copied
    from the signed request back onto the prepared request.
    """

    def __init__(
        self,
        properties: "SigV4SigningProperties",
        identity: "AWSCredentialIdentity",
    ):
        """Store the signing configuration and credentials.

        :param properties: signing properties forwarded to ``SigV4Signer.sign``.
        :param identity: credentials forwarded to ``SigV4Signer.sign``.
        """
        self._properties = properties
        self._identity = identity
        self._signer = SigV4Signer()

    def __eq__(self, other):
        # Compare the attributes that __init__ actually sets. Two auth objects
        # are interchangeable only if both the signing configuration and the
        # credentials match, mirroring requests' HTTPBasicAuth, which compares
        # username and password. getattr() keeps comparisons with unrelated
        # objects returning False instead of raising.
        return self._properties == getattr(
            other, "_properties", None
        ) and self._identity == getattr(other, "_identity", None)

    def __ne__(self, other):
        return not self == other

    def __call__(self, r):
        """Sign ``r`` in place and return it (the requests auth-hook protocol)."""
        self.sign_request(r)
        return r

    def sign_request(self, r: PreparedRequest) -> PreparedRequest:
        """Sign ``r`` and copy the resulting signing headers onto it.

        Only headers listed in ``SIGNING_HEADERS`` that are present on the
        signed request are written back. ``r`` is mutated and also returned.
        """
        request = self.convert_to_awsrequest(r)
        signed_request = self._signer.sign(
            properties=self._properties,
            request=request,
            identity=self._identity,
        )
        for header in SIGNING_HEADERS:
            if header in signed_request.fields:
                r.headers[header] = signed_request.fields[header].as_string()
        return r

    def convert_to_awsrequest(self, r: PreparedRequest) -> AWSRequest:
        """Build an ``AWSRequest`` with the URL, method, headers and body of ``r``.

        ``r`` itself is not modified.
        """
        url_parts = urlparse(r.url)
        uri = URI(
            scheme=url_parts.scheme,
            host=url_parts.hostname,
            port=url_parts.port,
            path=url_parts.path,
            query=url_parts.query,
            fragment=url_parts.fragment,
        )
        # Each requests header becomes a single-valued signer Field.
        fields = Fields([Field(name=k, values=[v]) for k, v in r.headers.items()])

        body = r.body
        if isinstance(body, str):
            # requests keeps text bodies (``data="..."``, or a dict that it
            # form-encodes) as ``str``. The signer's payload hash needs bytes,
            # so encode a copy for signing; r.body itself is left unchanged.
            # NOTE(review): UTF-8 matches urllib3 2.x's handling of str bodies.
            # Confirm the transport in use encodes the same way, or the payload
            # hash will not match for non-ASCII text.
            body = body.encode("utf-8")

        return AWSRequest(
            destination=uri,
            method=r.method,
            body=body,
            fields=fields,
        )