Skip to content

Commit fc8890b

Browse files
committed
refreshes credentials for irsa
1 parent 110245a commit fc8890b

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

prometrix/connect/aws_connect.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,20 @@ def __init__(
3434
if access_key and secret_key:
3535
# Backwards compatibility: use static keys
3636
self._credentials = Credentials(access_key, secret_key, token)
37+
self._has_static_keys = True
38+
self._session = None
3739
else:
3840
# IRSA
3941
session = boto3.Session()
4042
creds = session.get_credentials()
4143
if not creds:
4244
raise RuntimeError("No AWS credentials found (neither static keys nor IRSA)")
4345
self._credentials = creds
46+
self._has_static_keys = False
47+
self._session = session
4448

4549
role_to_assume = assume_role_arn or AWS_ASSUME_ROLE
50+
self._role_to_assume = role_to_assume
4651
if role_to_assume:
4752
self._assume_role(role_to_assume)
4853

@@ -93,6 +98,24 @@ def signed_request(
9398
params=params,
9499
)
95100

101+
def _refresh_credentials(self) -> None:
102+
"""
103+
Boto should automatically refresh expired credentials but when assuming role it cant be done automatically
104+
"""
105+
try:
106+
if not self._has_static_keys and self._session is not None:
107+
# this is also needed for assume role if base credentials fails
108+
refreshed = self._session.get_credentials()
109+
if refreshed:
110+
self._credentials = refreshed
111+
except Exception:
112+
logging.exception("Failed to refresh session credentials")
113+
if self._role_to_assume:
114+
try:
115+
self._assume_role(self._role_to_assume)
116+
except Exception:
117+
logging.exception("Failed to refresh assume role")
118+
96119
def _custom_query(self, query: str, params: dict = None):
97120
"""
98121
Send a custom query to a Prometheus Host.
@@ -121,6 +144,16 @@ def _custom_query(self, query: str, params: dict = None):
121144
verify=self.ssl_verification,
122145
headers=self.headers,
123146
)
147+
if response is not None and response.status_code == 403:
148+
self._refresh_credentials()
149+
response = self.signed_request(
150+
method="POST",
151+
url="{0}/api/v1/query".format(self.url),
152+
data={**{"query": query}, **params},
153+
params={},
154+
verify=self.ssl_verification,
155+
headers=self.headers,
156+
)
124157
return response
125158

126159
def safe_custom_query_range(
@@ -162,6 +195,18 @@ def safe_custom_query_range(
162195
params={},
163196
headers=self.headers,
164197
)
198+
if response is not None and response.status_code == 403:
199+
self._refresh_credentials()
200+
response = self.signed_request(
201+
method="POST",
202+
url="{0}/api/v1/query_range".format(self.url),
203+
data={
204+
**{"query": query, "start": start, "end": end, "step": step},
205+
**params,
206+
},
207+
params={},
208+
headers=self.headers,
209+
)
165210
if response.status_code == 200:
166211
return response.json()["data"]
167212
else:

0 commit comments

Comments
 (0)