@@ -34,15 +34,20 @@ def __init__(
3434 if access_key and secret_key :
3535 # Backwards compatibility: use static keys
3636 self ._credentials = Credentials (access_key , secret_key , token )
37+ self ._has_static_keys = True
38+ self ._session = None
3739 else :
3840 # IRSA
3941 session = boto3 .Session ()
4042 creds = session .get_credentials ()
4143 if not creds :
4244 raise RuntimeError ("No AWS credentials found (neither static keys nor IRSA)" )
4345 self ._credentials = creds
46+ self ._has_static_keys = False
47+ self ._session = session
4448
4549 role_to_assume = assume_role_arn or AWS_ASSUME_ROLE
50+ self ._role_to_assume = role_to_assume
4651 if role_to_assume :
4752 self ._assume_role (role_to_assume )
4853
@@ -93,6 +98,45 @@ def signed_request(
9398 params = params ,
9499 )
95100
101+ def _refresh_credentials (self ) -> None :
102+ """
103+ Boto should automatically refresh expired credentials but when assuming role it cant be done automatically
104+ """
105+ try :
106+ if not self ._has_static_keys and self ._session is not None :
107+ # this is also needed for assume role if base credentials fails
108+ refreshed = self ._session .get_credentials ()
109+ if refreshed :
110+ self ._credentials = refreshed
111+ except Exception :
112+ logging .exception ("Failed to refresh session credentials" )
113+ if self ._role_to_assume :
114+ try :
115+ self ._assume_role (self ._role_to_assume )
116+ except Exception :
117+ logging .exception ("Failed to refresh assume role" )
118+
119+ def _request_with_refresh (self , * , method , url , data = None , params = None , headers = None , verify = False ):
120+ resp = self .signed_request (
121+ method = method ,
122+ url = url ,
123+ data = data ,
124+ params = params ,
125+ verify = verify ,
126+ headers = headers ,
127+ )
128+ if resp is not None and resp .status_code in (400 , 401 , 403 ):
129+ self ._refresh_credentials ()
130+ resp = self .signed_request (
131+ method = method ,
132+ url = url ,
133+ data = data ,
134+ params = params ,
135+ verify = verify ,
136+ headers = headers ,
137+ )
138+ return resp
139+
96140 def _custom_query (self , query : str , params : dict = None ):
97141 """
98142 Send a custom query to a Prometheus Host.
@@ -113,7 +157,7 @@ def _custom_query(self, query: str, params: dict = None):
113157 data = None
114158 query = str (query )
115159 # using the query API to get raw data
116- response = self .signed_request (
160+ response = self ._request_with_refresh (
117161 method = "POST" ,
118162 url = "{0}/api/v1/query" .format (self .url ),
119163 data = {** {"query" : query }, ** params },
@@ -152,7 +196,7 @@ def safe_custom_query_range(
152196 params = params or {}
153197
154198 query = str (query )
155- response = self .signed_request (
199+ response = self ._request_with_refresh (
156200 method = "POST" ,
157201 url = "{0}/api/v1/query_range" .format (self .url ),
158202 data = {
0 commit comments