@@ -34,15 +34,20 @@ def __init__(
3434 if access_key and secret_key :
3535 # Backwards compatibility: use static keys
3636 self ._credentials = Credentials (access_key , secret_key , token )
37+ self ._has_static_keys = True
38+ self ._session = None
3739 else :
3840 # IRSA
3941 session = boto3 .Session ()
4042 creds = session .get_credentials ()
4143 if not creds :
4244 raise RuntimeError ("No AWS credentials found (neither static keys nor IRSA)" )
4345 self ._credentials = creds
46+ self ._has_static_keys = False
47+ self ._session = session
4448
4549 role_to_assume = assume_role_arn or AWS_ASSUME_ROLE
50+ self ._role_to_assume = role_to_assume
4651 if role_to_assume :
4752 self ._assume_role (role_to_assume )
4853
@@ -93,6 +98,24 @@ def signed_request(
9398 params = params ,
9499 )
95100
101+ def _refresh_credentials (self ) -> None :
102+ """
103+ Boto should automatically refresh expired credentials but when assuming role it cant be done automatically
104+ """
105+ try :
106+ if not self ._has_static_keys and self ._session is not None :
107+ # this is also needed for assume role if base credentials fails
108+ refreshed = self ._session .get_credentials ()
109+ if refreshed :
110+ self ._credentials = refreshed
111+ except Exception :
112+ logging .exception ("Failed to refresh session credentials" )
113+ if self ._role_to_assume :
114+ try :
115+ self ._assume_role (self ._role_to_assume )
116+ except Exception :
117+ logging .exception ("Failed to refresh assume role" )
118+
96119 def _custom_query (self , query : str , params : dict = None ):
97120 """
98121 Send a custom query to a Prometheus Host.
@@ -121,6 +144,16 @@ def _custom_query(self, query: str, params: dict = None):
121144 verify = self .ssl_verification ,
122145 headers = self .headers ,
123146 )
147+ if response is not None and response .status_code == 403 :
148+ self ._refresh_credentials ()
149+ response = self .signed_request (
150+ method = "POST" ,
151+ url = "{0}/api/v1/query" .format (self .url ),
152+ data = {** {"query" : query }, ** params },
153+ params = {},
154+ verify = self .ssl_verification ,
155+ headers = self .headers ,
156+ )
124157 return response
125158
126159 def safe_custom_query_range (
@@ -162,6 +195,18 @@ def safe_custom_query_range(
162195 params = {},
163196 headers = self .headers ,
164197 )
198+ if response is not None and response .status_code == 403 :
199+ self ._refresh_credentials ()
200+ response = self .signed_request (
201+ method = "POST" ,
202+ url = "{0}/api/v1/query_range" .format (self .url ),
203+ data = {
204+ ** {"query" : query , "start" : start , "end" : end , "step" : step },
205+ ** params ,
206+ },
207+ params = {},
208+ headers = self .headers ,
209+ )
165210 if response .status_code == 200 :
166211 return response .json ()["data" ]
167212 else :
0 commit comments