Skip to content

Commit b81abe2

Browse files
sdktjj authored and AndrewShakinovsky-SAS committed
To avoid the risk of HTTP calls running indefinitely, an optional parameter http_timeout can now be specified when creating a task. The default timeout is (30.05, 300), i.e. (connect, read) in seconds, and is applied to all HTTP calls.
1 parent 90e58f3 commit b81abe2

File tree

3 files changed

+55
-42
lines changed

3 files changed

+55
-42
lines changed

src/sas_airflow_provider/hooks/sas.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import urllib3
88
from urllib3.exceptions import InsecureRequestWarning
99

10-
1110
class SasHook(BaseHook):
1211
"""Hook to manage connection to SAS"""
1312

@@ -29,7 +28,7 @@ def __init__(self, conn_id: str = None) -> None:
2928
self.cert_verify = True
3029
self.grant_type = None
3130

32-
def get_conn(self):
31+
def get_conn(self, http_timeout=None):
3332
"""Returns a SAS connection."""
3433
if self.conn_id is None:
3534
self.conn_id = self.default_conn_name
@@ -55,11 +54,11 @@ def get_conn(self):
5554
self.log.info("Using custom TLS CA certificate bundle file")
5655

5756
if not self.sas_conn:
58-
self.sas_conn = self._create_session_for_connection()
57+
self.sas_conn = self._create_session_for_connection(http_timeout=http_timeout)
5958

6059
return self.sas_conn
6160

62-
def _create_session_for_connection(self):
61+
def _create_session_for_connection(self, http_timeout=None):
6362
self.log.info(f"Creating session for connection named %s to host %s",
6463
self.conn_id,
6564
self.host)
@@ -82,12 +81,15 @@ def _create_session_for_connection(self):
8281

8382
self.log.info("Get oauth token (see README if this crashes)")
8483
response = requests.post(
85-
f"{self.host}/SASLogon/oauth/token", data=payload, verify=self.cert_verify, headers=my_headers
84+
f"{self.host}/SASLogon/oauth/token",
85+
data=payload,
86+
verify=self.cert_verify,
87+
headers=my_headers,
88+
timeout=http_timeout
8689
)
87-
8890
if response.status_code != 200:
89-
raise RuntimeError(f"Get token failed: {response.text}")
90-
91+
raise RuntimeError(f"Get token failed with status code: {response.status_code}")
92+
9193
r = response.json()
9294
self.token = r["access_token"]
9395

src/sas_airflow_provider/operators/sas_studio.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from airflow.exceptions import AirflowException
2525
from airflow.models import BaseOperator
2626
from sas_airflow_provider.hooks.sas import SasHook
27-
from sas_airflow_provider.util.util import dump_logs, stream_log, create_or_connect_to_session, end_compute_session
27+
from sas_airflow_provider.util.util import stream_log, create_or_connect_to_session, end_compute_session
2828

2929
# main API URI for Code Gen
3030
URI_BASE = "/studioDevelopment/code"
@@ -88,6 +88,8 @@ class SASStudioOperator(BaseOperator):
8888
temporary unobtainable. When unknown_state_timeout is reached without the state being retrievable, the operator
8989
will throw an AirflowFailException and the task will be marked as failed.
9090
Default value is 0, meaning the task will fail immediately if the state could not be retrieved.
91+
:param http_timeout: (optional) Timeout for https requests. Default value is (30.05, 300), meaning a connect timeout slightly above 30 seconds and
92+
a read timeout of 300 seconds where the operator will wait for the server to send a response.
9193
"""
9294

9395
ui_color = "#CCE5FF"
@@ -113,6 +115,7 @@ def __init__(
113115
compute_session_id="",
114116
output_macro_var_prefix="",
115117
unknown_state_timeout=0,
118+
http_timeout=(30.05, 300),
116119
**kwargs,
117120
) -> None:
118121

@@ -146,6 +149,9 @@ def __init__(
146149
self.on_failure_callback=[on_failure]
147150
self.on_retry_callback=[on_retry]
148151

152+
# Timeout
153+
self.http_timeout=http_timeout
154+
149155

150156
def execute(self, context):
151157
if self.path_type not in ['compute', 'content', 'raw']:
@@ -158,13 +164,14 @@ def execute(self, context):
158164
try:
159165
self.log.info("Authenticate connection")
160166
h = SasHook(self.connection_name)
161-
self.connection = h.get_conn()
167+
self.connection = h.get_conn(http_timeout=self.http_timeout)
162168

163169
# Create compute session
164170
if not self.compute_session_id:
165171
compute_session = create_or_connect_to_session(self.connection,
166172
self.compute_context_name,
167-
AIRFLOW_SESSION_NAME if self.allways_reuse_session else None)
173+
AIRFLOW_SESSION_NAME if self.allways_reuse_session else None,
174+
http_timeout=self.http_timeout)
168175
self.compute_session_id = compute_session["id"]
169176
else:
170177
self.log.info(f"Compute Session {self.compute_session_id} was provided")
@@ -240,7 +247,7 @@ def _clean_up(self, also_kill_reused_session=False):
240247
if (also_kill_reused_session and self.allways_reuse_session) or self.allways_reuse_session==False:
241248
try:
242249
self.log.info(f"Deleting session with id {self.compute_session_id}")
243-
success_end = end_compute_session(self.connection, self.compute_session_id)
250+
success_end = end_compute_session(self.connection, self.compute_session_id, http_timeout=self.http_timeout)
244251
if success_end:
245252
self.log.info(f"Compute session succesfully deleted")
246253
else:
@@ -300,7 +307,7 @@ def _generate_object_code(self):
300307
"wrapperCode": self.codegen_wrap_code,
301308
}
302309

303-
response = self.connection.post(uri, json=req)
310+
response = self.connection.post(uri, json=req, timeout=self.http_timeout)
304311
if not response.ok:
305312
raise RuntimeError(f"Code generation failed: {response.text}")
306313

@@ -311,7 +318,7 @@ def _run_job_and_wait(self, job_request: dict, poll_interval: int) -> (dict, boo
311318

312319
#Kick off job request. if failures, no harm is done.
313320
try:
314-
response = self.connection.post(uri, json=job_request)
321+
response = self.connection.post(uri, json=job_request, timeout=self.http_timeout)
315322
except Exception as e:
316323
raise AirflowException(f"Error when creating Job Request {e}")
317324

@@ -336,7 +343,7 @@ def _run_job_and_wait(self, job_request: dict, poll_interval: int) -> (dict, boo
336343
time.sleep(poll_interval)
337344

338345
try:
339-
response = self.connection.get(uri)
346+
response = self.connection.get(uri, timeout=self.http_timeout)
340347
if not response.ok:
341348
countUnknownState = countUnknownState + 1
342349
self.log.info(f'Invalid response code {response.status_code} from {uri}. Will set state=unknown and continue checking...')
@@ -352,7 +359,7 @@ def _run_job_and_wait(self, job_request: dict, poll_interval: int) -> (dict, boo
352359

353360
# Get the latest new log lines.
354361
if self.exec_log and state != "unknown":
355-
num_log_lines=stream_log(self.connection, job, num_log_lines)
362+
num_log_lines=stream_log(self.connection, job, num_log_lines, http_timeout=self.http_timeout)
356363

357364
except Exception as e:
358365
countUnknownState = countUnknownState + 1
@@ -365,7 +372,7 @@ def _run_job_and_wait(self, job_request: dict, poll_interval: int) -> (dict, boo
365372

366373
# Be sure to get the latest new log lines after the job has finished.
367374
if self.exec_log:
368-
num_log_lines=stream_log(self.connection, job, num_log_lines)
375+
num_log_lines=stream_log(self.connection, job, num_log_lines, http_timeout=self.http_timeout)
369376

370377
self.log.info("Job request has completed execution with the status: " + str(state))
371378
success = True
@@ -381,7 +388,7 @@ def _set_output_variables(self, context):
381388

382389
# retrieve variables from compute session
383390
uri = f"/compute/sessions/{self.compute_session_id}/variables?limit=999&filter=startsWith(name,'{self.output_macro_var_prefix}')"
384-
response = self.connection.get(uri, headers={'Accept': '*/*'})
391+
response = self.connection.get(uri, headers={'Accept': '*/*'}, timeout=self.http_timeout)
385392
if not response.ok:
386393
raise RuntimeError(f"get compute variables failed with {response.status_code}")
387394
v = response.json()["items"]

src/sas_airflow_provider/util/util.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,38 @@
2222
import logging
2323

2424

25-
def get_folder_file_contents(session, path: str) -> str:
25+
def get_folder_file_contents(session, path: str, http_timeout=None) -> str:
2626
"""
2727
Fetch a file from folder service
2828
:param session:
2929
:param path:
30+
:param http_timeout: Timeout for http connection
3031
:return:
3132
"""
3233
member = get_member_by_path(session, path)
3334
if member['contentType'] != 'file':
3435
raise RuntimeError(f"folder item is not a file: '{path}'")
3536

3637
uri = member['uri'] + '/content'
37-
response = session.get(uri)
38+
response = session.get(uri, timeout=http_timeout)
3839
if not response.ok:
3940
raise RuntimeError(f"File {path} was not found or could not be accessed. error code: {response.status_code}")
4041

4142
return response.text
4243

4344

44-
def get_folder_by_path(session, path: str) -> dict:
45+
def get_folder_by_path(session, path: str, http_timeout=None) -> dict:
4546
"""
4647
Get a folder given the path.
4748
Return a folder object, or raise an error
4849
"""
49-
response = session.get('/folders/folders/@item', params={'path': path})
50+
response = session.get('/folders/folders/@item', params={'path': path}, timeout=http_timeout)
5051
if response.ok:
5152
return response.json()
5253
raise RuntimeError(response.text)
5354

5455

55-
def get_member_by_path(session, path: str) -> dict:
56+
def get_member_by_path(session, path: str, http_timeout=None) -> dict:
5657
"""
5758
Get a folder member given the full path.
5859
Return a folder member (object), or an empty dict if not found
@@ -61,12 +62,12 @@ def get_member_by_path(session, path: str) -> dict:
6162
if len(parts) < 2:
6263
raise RuntimeError(f"invalid path '{path}'")
6364

64-
f = get_folder_by_path(session, parts[0])
65+
f = get_folder_by_path(session, parts[0], http_timeout=http_timeout)
6566

6667
uri = get_uri(f['links'], 'members')
6768
if not uri:
6869
raise RuntimeError("failed to find members uri link")
69-
response = session.get(uri, params={'filter': f'eq("name","{parts[1]}")'})
70+
response = session.get(uri, params={'filter': f'eq("name","{parts[1]}")'}, timeout=http_timeout)
7071

7172
if not response.ok:
7273
raise RuntimeError(f"failed to get folder members for '{path}'")
@@ -79,18 +80,19 @@ def get_member_by_path(session, path: str) -> dict:
7980
return member
8081

8182

82-
def get_compute_session_file_contents(session, compute_session, path: str) -> str:
83+
def get_compute_session_file_contents(session, compute_session, path: str, http_timeout=None) -> str:
8384
"""
8485
Fetch a file from the compute session file system
8586
:param session: the rest session that includes auth token
8687
:param compute_session: the compute session id
8788
:param path: full path to the file in the file system
89+
:param http_timeout: Timeout for http connection
8890
:return: contents of the file
8991
"""
9092
p = f'{path.replace("/", "~fs~")}'
9193
uri = f'/compute/sessions/{compute_session}/files/{p}/content'
9294

93-
response = session.get(uri, headers={"Accept": "application/octet-stream"})
95+
response = session.get(uri, headers={"Accept": "application/octet-stream"}, timeout=http_timeout)
9496
if response.ok:
9597
return response.text
9698
raise RuntimeError(f"File {path} was not found or could not be accessed. error code: {response.status_code}")
@@ -107,19 +109,19 @@ def get_uri(links, rel):
107109
return link["uri"]
108110

109111

110-
def stream_log(session,job,start,limit=99999) -> int:
112+
def stream_log(session,job,start,limit=99999, http_timeout=None) -> int:
111113
current_line=start
112114

113115
log_uri = get_uri(job["links"], "log")
114116
if not log_uri:
115-
logging.getLogger(name=None).warning("Warning: failed to retrieve log URI from links")
117+
logging.getLogger(name=None).warning("Warning: failed to retrieve log URI. Maybe the log is too large.")
116118
else:
117119
try:
118120
# Note: if it is a files link (which it will be once the job has finished), it does not support the 'start' parameter, so we need to filter the lines ourselves.
119121
# We will ignore the limit parameter in that case
120122
is_files_link=log_uri.startswith("/files/")
121123

122-
r = session.get(f"{log_uri}/content?start={start}&limit={limit}")
124+
r = session.get(f"{log_uri}/content?start={start}&limit={limit}", timeout=http_timeout)
123125
if r.ok:
124126
# Parse the json log format and print each line
125127
log_contents = r.text
@@ -134,26 +136,27 @@ def stream_log(session,job,start,limit=99999) -> int:
134136

135137
lines=lines+1
136138
else:
137-
logging.getLogger(name=None).warning(f"Failed to retrieve part of the log from URI: {log_uri}/content ")
139+
logging.getLogger(name=None).warning(f"Failed to retrieve parts of the log with status code {r.status_code} from URI: {log_uri}/content. Maybe the log is too large.")
138140
except Exception as e:
139-
logging.getLogger(name=None).warning("Unable to retrieve parts of the log.")
141+
logging.getLogger(name=None).warning(f"Unable to retrieve parts of the log: {e}. Maybe the log is too large.")
140142

141143
return current_line
142144

143145

144146

145-
def dump_logs(session, job):
147+
def dump_logs(session, job, http_timeout=None):
146148
"""
147149
Get the log from the job object
148150
:param session: rest session
149151
:param job: job object that should contain links object
152+
:param http_timeout: Timeout for http connection
150153
"""
151154

152155
log_uri = get_uri(job["links"], "log")
153156
if not log_uri:
154157
print("Warning: failed to retrieve log uri from links. Log will not be displayed")
155158
else:
156-
r = session.get(f"{log_uri}/content")
159+
r = session.get(f"{log_uri}/content", timeout=http_timeout)
157160
if not r.ok:
158161
print("Warning: failed to retrieve log content. Log will not be displayed")
159162

@@ -165,9 +168,9 @@ def dump_logs(session, job):
165168
if t != "title":
166169
print(f'{line["line"]}')
167170

168-
def find_named_compute_session(session: requests.Session, name: str) -> dict:
171+
def find_named_compute_session(session: requests.Session, name: str, http_timeout=None) -> dict:
169172
# find session with given name
170-
response = session.get(f"/compute/sessions?filter=eq(name, {name})")
173+
response = session.get(f"/compute/sessions?filter=eq(name, {name})", timeout=http_timeout)
171174
if not response.ok:
172175
raise RuntimeError(f"Find sessions failed: {response.status_code}")
173176
sessions = response.json()
@@ -176,18 +179,19 @@ def find_named_compute_session(session: requests.Session, name: str) -> dict:
176179
return sessions["items"][0]
177180
return {}
178181

179-
def create_or_connect_to_session(session: requests.Session, context_name: str, name = None) -> dict:
182+
def create_or_connect_to_session(session: requests.Session, context_name: str, name = None, http_timeout=None) -> dict:
180183
"""
181184
Connect to an existing compute session by name. If that named session does not exist,
182185
one is created using the context name supplied
183186
:param session: rest session that includes oauth token
184187
:param context_name: the context name to use to create the session if the session was not found
185188
:param name: name of session to find
189+
:param http_timeout: Timeout for http connection
186190
:return: session object
187191
188192
"""
189193
if name != None:
190-
compute_session = find_named_compute_session(session, name)
194+
compute_session = find_named_compute_session(session, name, http_timeout=http_timeout)
191195
if compute_session:
192196
return compute_session
193197

@@ -197,7 +201,7 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n
197201

198202

199203
# find compute context
200-
response = session.get("/compute/contexts", params={"filter": f'eq("name","{context_name}")'})
204+
response = session.get("/compute/contexts", params={"filter": f'eq("name","{context_name}")'},timeout=http_timeout)
201205
if not response.ok:
202206
raise RuntimeError(f"Find context named {context_name} failed: {response.status_code}")
203207
context_resp = response.json()
@@ -216,7 +220,7 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n
216220
headers = {"Content-Type": "application/vnd.sas.compute.session.request+json"}
217221

218222
req = json.dumps(session_request)
219-
response = session.post(uri, data=req, headers=headers)
223+
response = session.post(uri, data=req, headers=headers, timeout=http_timeout)
220224

221225
if response.status_code != 201:
222226
raise RuntimeError(f"Failed to create session: {response.text}")
@@ -226,9 +230,9 @@ def create_or_connect_to_session(session: requests.Session, context_name: str, n
226230

227231
return json_response
228232

229-
def end_compute_session(session: requests.Session, id):
233+
def end_compute_session(session: requests.Session, id, http_timeout=None):
230234
uri = f'/compute/sessions/{id}'
231-
response = session.delete(uri)
235+
response = session.delete(uri, timeout=http_timeout)
232236
if not response.ok:
233237
return False
234238
return True

0 commit comments

Comments
 (0)