Skip to content

Commit 2b45ed9

Browse files
committed
Added a healthcheck (heartbeat) to the Trino Python client, enabling it to process queries that take more than 15 minutes
1 parent 606a033 commit 2b45ed9

File tree

2 files changed

+196
-1
lines changed

2 files changed

+196
-1
lines changed

tests/unit/test_client.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,3 +1447,149 @@ def delete_password(self, servicename, username):
14471447
return None
14481448

14491449
os.remove(file_path)
1450+
1451+
1452+
@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_success(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat keeps probing and stays enabled while the server answers 200.

    The test polls until two HEAD probes have been observed (bounded by a
    generous deadline) instead of sleeping for a fixed time, so it does not
    depend on how quickly the background thread gets scheduled.
    """
    head_call_count = 0

    def fake_head(url, timeout=10):
        nonlocal head_call_count
        head_call_count += 1

        class Resp:
            status_code = 200

        return Resp()

    mock_requests.head.side_effect = fake_head
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.1)
    # The heartbeat loop exits immediately when there is no next URI to probe.
    query._next_uri = "http://coordinator/v1/statement/next"

    query._start_heartbeat()
    try:
        deadline = time.monotonic() + 5
        while head_call_count < 2 and time.monotonic() < deadline:
            time.sleep(0.01)
    finally:
        # Always stop the background thread, even if something above raised.
        query._stop_heartbeat()

    assert head_call_count >= 2
    # Successful probes must never switch the heartbeat off.
    assert query._heartbeat_enabled
1483+
1484+
@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_failure_stops(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat disables itself after exactly 3 consecutive failed probes.

    Every HEAD probe answers 500. The test waits for the heartbeat thread to
    exit on its own (bounded join) and then checks both the disabled flag and
    the number of probes that were sent.
    """
    def fake_head(url, timeout=10):
        class Resp:
            status_code = 500

        return Resp()

    mock_requests.head.side_effect = fake_head
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])

    query._start_heartbeat()
    heartbeat_thread = query._heartbeat_thread
    try:
        # The loop must terminate by itself; no stop signal is sent here.
        heartbeat_thread.join(timeout=5)
        assert not heartbeat_thread.is_alive()
        assert not query._heartbeat_enabled
        # One probe per failure, and the loop gives up on the third failure.
        assert mock_requests.head.call_count == 3
    finally:
        query._stop_heartbeat()
1508+
1509+
@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_404_405_stops(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat disables itself on the first 404 or 405 response.

    For each status code a fresh query is created, the heartbeat thread is
    given a bounded amount of time to exit on its own, and the test checks
    that exactly one probe was sent before the heartbeat was switched off.
    """
    for code in (404, 405):
        calls = []

        # `code` and `calls` are bound as defaults so each iteration's fake
        # uses its own values rather than the loop's latest ones.
        def fake_head(url, timeout=10, code=code, calls=calls):
            calls.append(url)

            class Resp:
                status_code = code

            return Resp()

        mock_requests.head.side_effect = fake_head
        mock_requests.Response.return_value.json.return_value = sample_post_response_data
        mock_requests.get.return_value.json.return_value = sample_get_response_data
        mock_requests.post.return_value.json.return_value = sample_post_response_data
        req = TrinoRequest(
            host="coordinator",
            port=8080,
            client_session=ClientSession(user="test"),
            http_scheme="http",
        )
        query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
        query._next_uri = "http://coordinator/v1/statement/next"
        query._row_mapper = mock.Mock(map=lambda x: [])

        query._start_heartbeat()
        heartbeat_thread = query._heartbeat_thread
        try:
            # The loop must terminate by itself; no stop signal is sent here.
            heartbeat_thread.join(timeout=5)
            assert not heartbeat_thread.is_alive()
            assert not query._heartbeat_enabled
            # 404/405 is treated as "unsupported": no retry after the first probe.
            assert len(calls) == 1
        finally:
            query._stop_heartbeat()
1534+
1535+
@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_stops_on_finish(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat thread exits on its own once the query is marked finished.

    The test first waits until at least one probe has been sent, then flags
    the query as finished and verifies that the heartbeat thread terminates
    without `_stop_heartbeat()` having been called.
    """
    head_call_count = 0

    def fake_head(url, timeout=10):
        nonlocal head_call_count
        head_call_count += 1

        class Resp:
            status_code = 200

        return Resp()

    mock_requests.head.side_effect = fake_head
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])

    query._start_heartbeat()
    heartbeat_thread = query._heartbeat_thread
    try:
        deadline = time.monotonic() + 5
        while head_call_count < 1 and time.monotonic() < deadline:
            time.sleep(0.01)
        assert head_call_count >= 1

        query._finished = True
        # No stop signal is sent: the loop has to notice the finished flag itself.
        heartbeat_thread.join(timeout=5)
        assert not heartbeat_thread.is_alive()
    finally:
        query._stop_heartbeat()
1565+
1566+
@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_stops_on_cancel(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat thread exits on its own once the query is marked cancelled.

    The test first waits until at least one probe has been sent, then flags
    the query as cancelled and verifies that the heartbeat thread terminates
    without `_stop_heartbeat()` having been called.
    """
    head_call_count = 0

    def fake_head(url, timeout=10):
        nonlocal head_call_count
        head_call_count += 1

        class Resp:
            status_code = 200

        return Resp()

    mock_requests.head.side_effect = fake_head
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])

    query._start_heartbeat()
    heartbeat_thread = query._heartbeat_thread
    try:
        deadline = time.monotonic() + 5
        while head_call_count < 1 and time.monotonic() < deadline:
            time.sleep(0.01)
        assert head_call_count >= 1

        query._cancelled = True
        # No stop signal is sent: the loop has to notice the cancelled flag itself.
        heartbeat_thread.join(timeout=5)
        assert not heartbeat_thread.is_alive()
    finally:
        query._stop_heartbeat()

trino/client.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,8 @@ def __init__(
808808
request: TrinoRequest,
809809
query: str,
810810
legacy_primitive_types: bool = False,
811-
fetch_mode: Literal["mapped", "segments"] = "mapped"
811+
fetch_mode: Literal["mapped", "segments"] = "mapped",
812+
heartbeat_interval: float = 60.0, # seconds
812813
) -> None:
813814
self._query_id: Optional[str] = None
814815
self._stats: Dict[Any, Any] = {}
@@ -826,6 +827,11 @@ def __init__(
826827
self._legacy_primitive_types = legacy_primitive_types
827828
self._row_mapper: Optional[RowMapper] = None
828829
self._fetch_mode = fetch_mode
830+
self._heartbeat_interval = heartbeat_interval
831+
self._heartbeat_thread = None
832+
self._heartbeat_stop_event = threading.Event()
833+
self._heartbeat_failures = 0
834+
self._heartbeat_enabled = True
829835

830836
@property
831837
def query_id(self) -> Optional[str]:
@@ -868,6 +874,39 @@ def result(self):
868874
def info_uri(self):
869875
return self._info_uri
870876

877+
def _start_heartbeat(self):
    """Start the background heartbeat thread unless one is already running.

    Nothing is started when the configured interval is ``None`` or not
    positive; such a value is treated as "heartbeat disabled". A handle to a
    thread that has already exited does not block a new start.
    """
    interval = self._heartbeat_interval
    if interval is None or interval <= 0:
        # With a non-positive interval the loop's Event.wait() returns at once,
        # which would send HEAD requests back-to-back without any pause.
        return
    existing = self._heartbeat_thread
    if existing is not None and existing.is_alive():
        return
    # Clear a stop request left over from a previous _stop_heartbeat() call.
    self._heartbeat_stop_event.clear()
    self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True)
    self._heartbeat_thread.start()
883+
884+
def _stop_heartbeat(self):
    """Signal the heartbeat thread to stop and wait briefly for it to exit.

    The wait is capped at 2 seconds, so a probe that is still in flight does
    not hold up the caller; the thread handle is dropped either way.
    """
    self._heartbeat_stop_event.set()
    worker = self._heartbeat_thread
    # Thread.join() raises RuntimeError when a thread tries to join itself, so
    # skip the wait if this is ever reached from the heartbeat thread.
    if worker is not None and worker is not threading.current_thread():
        worker.join(timeout=2)
    self._heartbeat_thread = None
889+
890+
def _heartbeat_loop(self):
    """Send periodic HEAD probes to the query's next URI until told to stop.

    The loop ends when the stop event is set, the query is finished or
    cancelled, the heartbeat has been disabled, or there is no next URI.
    A 404 or 405 response disables the heartbeat immediately. Any other
    non-200 response, or any exception raised by the probe, counts as a
    failure; a 200 response resets the failure counter. After 3 consecutive
    failures the heartbeat is disabled.
    """
    max_consecutive_failures = 3
    probe_timeout = 10  # seconds
    unsupported_statuses = (404, 405)

    while (
        not self._heartbeat_stop_event.is_set()
        and not self.finished
        and not self.cancelled
        and self._heartbeat_enabled
    ):
        # Read the URI once so the None check and the request use the same
        # value even if another thread reassigns it in between.
        next_uri = self._next_uri
        if next_uri is None:
            break
        try:
            status = self._request.http.head(next_uri, timeout=probe_timeout).status_code
        except Exception:
            # Best effort: a failing probe must never take the query down, but
            # it should at least be visible in the debug log.
            self._heartbeat_failures += 1
            logger.debug("heartbeat probe failed for query %s", self.query_id, exc_info=True)
        else:
            if status in unsupported_statuses:
                logger.debug(
                    "heartbeat disabled for query %s: server answered %s",
                    self.query_id,
                    status,
                )
                self._heartbeat_enabled = False
                break
            if status == 200:
                self._heartbeat_failures = 0
            else:
                self._heartbeat_failures += 1
        if self._heartbeat_failures >= max_consecutive_failures:
            logger.debug(
                "heartbeat disabled for query %s after %s consecutive failures",
                self.query_id,
                self._heartbeat_failures,
            )
            self._heartbeat_enabled = False
            break
        # Event.wait() doubles as an interruptible sleep: it returns early as
        # soon as the stop event is set.
        self._heartbeat_stop_event.wait(self._heartbeat_interval)
909+
871910
def execute(self, additional_http_headers=None) -> TrinoResult:
872911
"""Initiate a Trino query by sending the SQL statement
873912
@@ -895,6 +934,9 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
895934
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
896935
self._result = TrinoResult(self, rows)
897936

937+
# Start heartbeat thread
938+
self._start_heartbeat()
939+
898940
# Execute should block until at least one row is received or query is finished or cancelled
899941
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
900942
self._result.rows += self.fetch()
@@ -921,6 +963,7 @@ def fetch(self) -> List[Union[List[Any]], Any]:
921963
self._update_state(status)
922964
if status.next_uri is None:
923965
self._finished = True
966+
self._stop_heartbeat()
924967

925968
if not self._row_mapper:
926969
return []
@@ -968,6 +1011,7 @@ def cancel(self) -> None:
9681011
if response.status_code == requests.codes.no_content:
9691012
self._cancelled = True
9701013
logger.debug("query cancelled: %s", self.query_id)
1014+
self._stop_heartbeat()
9711015
return
9721016

9731017
self._request.raise_response_error(response)
@@ -985,6 +1029,11 @@ def finished(self) -> bool:
9851029
def cancelled(self) -> bool:
9861030
return self._cancelled
9871031

1032+
@property
def is_running(self) -> bool:
    """Tell whether the query is still in flight (neither finished nor cancelled)."""
    ended = self.finished or self.cancelled
    return not ended
1036+
9881037

9891038
def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
9901039
def wrapper(func):

0 commit comments

Comments
 (0)