diff --git a/chronos/worker.py b/chronos/worker.py index 700225a..1b126a4 100644 --- a/chronos/worker.py +++ b/chronos/worker.py @@ -45,23 +45,50 @@ async def webhook_request(client: AsyncClient, url: str, endpoint_id: int, *, we 'Content-Type': 'application/json', 'webhook-signature': webhook_sig, } + request_data = RequestData( + endpoint_id=endpoint_id, request_headers=json.dumps(headers), request_body=json.dumps(data) + ) with logfire.span('{method=} {url!r}', url=url, method='POST'): - r = None try: + if not data or (isinstance(data, dict) and not data.get('events')): + request_data.status_code = 999 + request_data.response_headers = json.dumps({}) + request_data.response_body = json.dumps({'error': 'Empty payload'}) + request_data.successful_response = False + return request_data + + if isinstance(data, dict) and data.get('events') and not data['events'][0].get('data'): + request_data.status_code = 999 + request_data.response_headers = json.dumps({}) + request_data.response_body = json.dumps({'error': 'Empty event data'}) + request_data.successful_response = False + return request_data + r = await client.post(url=url, json=data, headers=headers, timeout=8) + request_data.response_headers = json.dumps(dict(r.headers)) + try: + response_body = r.json() + except json.JSONDecodeError: + response_body = r.content.decode() + request_data.response_body = json.dumps(response_body) + request_data.status_code = r.status_code + request_data.successful_response = r.status_code in {200, 201, 202, 204} + return request_data except httpx.TimeoutException as terr: app_logger.info('Timeout error sending webhook to %s: %s', url, terr) + request_data.status_code = 999 + request_data.response_headers = json.dumps({}) # Empty headers for timeout + request_data.response_body = json.dumps({'error': 'Timeout error'}) + request_data.successful_response = False + raise terr except httpx.HTTPError as httperr: app_logger.info('HTTP error sending webhook to %s: %s', url, httperr) - request_data = RequestData( - endpoint_id=endpoint_id, request_headers=json.dumps(headers), request_body=json.dumps(data) - ) - if r is not None: - request_data.response_headers = json.dumps(dict(r.headers)) - request_data.response_body = json.dumps(r.content.decode()) - request_data.status_code = r.status_code - request_data.successful_response = True - return request_data + response = getattr(httperr, 'response', httpx.Response(status_code=500)) + request_data.status_code = response.status_code + request_data.response_headers = json.dumps(dict(response.headers) if response.headers else {}) + request_data.response_body = json.dumps({'error': str(httperr)}) + request_data.successful_response = False + raise httperr acceptable_url_schemes = ('http', 'https', 'ftp', 'ftps') @@ -107,21 +134,45 @@ async def _async_post_webhooks(endpoints, url_extension, payload): if url_extension: url += f'/{url_extension}' # Send the Webhook to the endpoint + try: + loaded_payload = json.loads(payload) + task = asyncio.ensure_future( + webhook_request(client, url, endpoint.id, webhook_sig=sig_hex, data=loaded_payload) + ) + tasks.append(task) + except json.JSONDecodeError: + app_logger.error('Failed to decode payload for endpoint %s', endpoint.id) + continue - loaded_payload = json.loads(payload) - task = asyncio.ensure_future( - webhook_request(client, url, endpoint.id, webhook_sig=sig_hex, data=loaded_payload) - ) - tasks.append(task) webhook_responses = await asyncio.gather(*tasks, return_exceptions=True) for response in webhook_responses: + if isinstance(response, Exception): + app_logger.info('Error from endpoint %s: %s', endpoint.id, response) + webhook_logs.append( + WebhookLog( + webhook_endpoint_id=endpoint.id, + request_headers=json.dumps({}), + request_body=payload, + response_headers=json.dumps({}), + response_body=json.dumps({'error': str(response)}), + status='Unexpected response', + status_code=999, + ) + ) + total_failed += 1 + continue + if not isinstance(response, RequestData): app_logger.info('No response from endpoint %s: %s. %s', endpoint.id, endpoint.webhook_url, response) continue - elif not response.successful_response: - app_logger.info('No response from endpoint %s: %s', endpoint.id, endpoint.webhook_url) - if response.status_code in {200, 201, 202, 204}: + try: + response_body = json.loads(response.response_body) + response_status = response_body.get('status', '').lower() # Default to empty string if not specified + except (json.JSONDecodeError, AttributeError): + response_status = 'success' # Default to success on parse error + + if response.status_code in {200, 201, 202, 204} and response_status == 'success': status = 'Success' total_success += 1 else: @@ -151,41 +202,49 @@ def task_send_webhooks( """ Send the webhook to the relevant endpoints """ - loaded_payload = json.loads(payload) - loaded_payload['_request_time'] = loaded_payload.pop('request_time') - qlength = get_qlength() - - if loaded_payload.get('events'): - branch_id = loaded_payload['events'][0]['branch'] - else: - branch_id = loaded_payload['branch_id'] - - if qlength > 100: - app_logger.error('Queue is too long. Check workers and speeds.') - - app_logger.info('Starting send webhook task for branch %s. qlength=%s.', branch_id, qlength) - lf_span = 'Sending webhooks for branch: {branch_id=}' - with logfire.span(lf_span, branch_id=branch_id): - with Session(engine) as db: - # Get all the endpoints for the branch - endpoints_query = select(WebhookEndpoint).where( - WebhookEndpoint.branch_id == branch_id, WebhookEndpoint.active - ) - endpoints = db.exec(endpoints_query).all() + try: + loaded_payload = json.loads(payload) + if 'request_time' in loaded_payload: + loaded_payload['_request_time'] = loaded_payload.pop('request_time') + qlength = get_qlength() + + if loaded_payload.get('events'): + branch_id = loaded_payload['events'][0]['branch'] + else: + branch_id = loaded_payload['branch_id'] + + if qlength > 100: + app_logger.error('Queue is too long. Check workers and speeds.') + + app_logger.info('Starting send webhook task for branch %s. qlength=%s.', branch_id, qlength) + lf_span = 'Sending webhooks for branch: {branch_id=}' + with logfire.span(lf_span, branch_id=branch_id): + with Session(engine) as db: + # Get all the endpoints for the branch + endpoints_query = select(WebhookEndpoint).where( + WebhookEndpoint.branch_id == branch_id, WebhookEndpoint.active + ) + endpoints = db.exec(endpoints_query).all() - webhook_logs, total_success, total_failed = asyncio.run( - _async_post_webhooks(endpoints, url_extension, payload) - ) - for webhook_log in webhook_logs: - db.add(webhook_log) - db.commit() - app_logger.info( - '%s Webhooks sent for branch %s. Total Sent: %s. Total failed: %s', - total_success + total_failed, - branch_id, - total_success, - total_failed, - ) + webhook_logs, total_success, total_failed = asyncio.run( + _async_post_webhooks(endpoints, url_extension, payload) + ) + for webhook_log in webhook_logs: + db.add(webhook_log) + db.commit() + app_logger.info( + '%s Webhooks sent for branch %s. Total Sent: %s. Total failed: %s', + total_success + total_failed, + branch_id, + total_success, + total_failed, + ) + except json.JSONDecodeError as e: + app_logger.error('Failed to decode payload: %s', payload) + raise e + except Exception as e: + app_logger.error('Error sending webhooks: %s', str(e)) + raise e DELETE_JOBS_KEY = 'delete_old_logs_job' diff --git a/conftest.py b/conftest.py index 6769a68..04ad823 100644 --- a/conftest.py +++ b/conftest.py @@ -23,6 +23,11 @@ def create_tables(engine): SQLModel.metadata.drop_all(engine) + +@pytest.fixture(scope='session') +def celery_includes(): + return ['chronos.worker'] + @pytest.fixture def session(engine, create_tables): connection = engine.connect() @@ -48,7 +53,12 @@ def get_session_override(): @pytest.fixture(scope='session') def celery_config(): - return {'broker_url': 'redis://', 'result_backend': 'redis://'} + return { + 'broker_url': 'redis://', + 'result_backend': 'redis://', + 'task_always_eager': True, + 'task_eager_propagates': True, + } @pytest.fixture(scope='session') diff --git a/tests/test_helpers.py b/tests/test_helpers.py index d5db143..db1f576 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,9 @@ +import hashlib +import hmac import json +import httpx from httpx import Response -from requests import Request from chronos.main import app from chronos.sql_models import WebhookEndpoint, WebhookLog @@ -142,22 +144,28 @@ def create_webhook_log_from_dft_data(**kwargs) -> WebhookLog: def get_successful_response(payload, headers, **kwargs) -> Response: - response_dict = {'status_code': 200, 'message': 'success'} + response_dict = {'status': 'success', 'message': 'success'} for k, v in kwargs.items(): response_dict[k] = v - request = Request() - request.headers = headers - request.body = json.dumps(payload).encode() - response = Response(status_code=200, request=request, content=json.dumps(response_dict).encode()) - return response + headers = headers.copy() + headers['webhook-signature'] = hmac.new(b'test_key', json.dumps(payload).encode(), hashlib.sha256).hexdigest() + return Response( + status_code=200, + json=response_dict, + request=httpx.Request('POST', 'https://example.com', json=payload, headers=headers), + headers=headers, + ) def get_failed_response(payload, headers, **kwargs) -> Response: - response_dict = {'status_code': 409, 'message': 'Bad request'} + response_dict = {'status': 'error', 'message': 'Bad request'} for k, v in kwargs.items(): response_dict[k] = v - request = Request() - request.headers = headers - request.body = json.dumps(payload).encode() - response = Response(status_code=409, request=request, content=json.dumps(response_dict).encode()) - return response + headers = headers.copy() + headers['webhook-signature'] = hmac.new(b'test_key', json.dumps(payload).encode(), hashlib.sha256).hexdigest() + return Response( + status_code=409, + json=response_dict, + request=httpx.Request('POST', 'https://example.com', json=payload, headers=headers), + headers=headers, + ) diff --git a/tests/test_worker.py b/tests/test_worker.py index fe545af..1386d0c 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,15 +1,25 @@ +import asyncio +import hashlib +import hmac import json -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from unittest.mock import patch import httpx import pytest import respx as respx from fastapi.testclient import TestClient -from sqlmodel import Session, SQLModel, col, select +from sqlmodel import Session, SQLModel, delete, select from chronos.sql_models import WebhookEndpoint, WebhookLog -from chronos.worker import _delete_old_logs_job, task_send_webhooks +from chronos.worker import ( + _async_post_webhooks, + _delete_old_logs_job, + delete_old_logs_job, + get_count, + task_send_webhooks, + webhook_request, +) from tests.test_helpers import ( _get_webhook_headers, create_endpoint_from_dft_data, @@ -33,14 +43,15 @@ def db(self, engine, create_tables): yield session @respx.mock - def test_send_webhook_one(self, db: Session, client: TestClient, celery_session_worker): + async def test_send_webhook_one(self, db: Session, client: TestClient, celery_session_worker): ep = create_endpoint_from_dft_data()[0] db.add(ep) db.commit() payload = get_dft_webhook_data() - headers = _get_webhook_headers() - mock_request = respx.post(ep.webhook_url).mock(return_value=get_successful_response(payload, headers)) + + # Set up mock response for async request with pattern matching + pattern = respx.post(ep.webhook_url).mock(return_value=get_successful_response(payload, _get_webhook_headers())) endpoints = db.exec(select(WebhookEndpoint)).all() assert len(endpoints) == 1 @@ -48,27 +59,31 @@ def test_send_webhook_one(self, db: Session, client: TestClient, celery_session_ webhooks = db.exec(select(WebhookLog)).all() assert len(webhooks) == 0 + # Test without extension sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) sending_webhooks.get(timeout=10) - assert mock_request.called webhooks = db.exec(select(WebhookLog)).all() assert len(webhooks) == 1 + assert pattern.call_count == 1 webhook = webhooks[0] assert webhook.status == 'Success' assert webhook.status_code == 200 + assert webhook.webhook_endpoint_id == ep.id @respx.mock - def test_send_many_endpoints(self, db: Session, client: TestClient, celery_session_worker): - endpoints = db.exec(select(WebhookEndpoint)).all() - assert len(endpoints) == 0 - + async def test_send_many_endpoints(self, db: Session, client: TestClient, celery_session_worker): eps = create_endpoint_from_dft_data(count=10) payload = get_dft_webhook_data() - headers = _get_webhook_headers() + + # Set up mock responses for all endpoints with pattern matching + patterns = [] for ep in eps: - mock_request = respx.post(ep.webhook_url).mock(return_value=get_successful_response(payload, headers)) + pattern = respx.post(ep.webhook_url).mock( + return_value=get_successful_response(payload, _get_webhook_headers()) + ) + patterns.append(pattern) db.add(ep) db.commit() @@ -78,93 +93,55 @@ def test_send_many_endpoints(self, db: Session, client: TestClient, celery_sessi webhooks = db.exec(select(WebhookLog)).all() assert len(webhooks) == 0 + # Test without extension sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) sending_webhooks.get(timeout=10) + webhooks = db.exec(select(WebhookLog)).all() assert len(webhooks) == 10 - assert mock_request.called - @respx.mock - def test_send_correct_branch(self, db: Session, client: TestClient, celery_session_worker): - endpoints = db.exec(select(WebhookEndpoint)).all() - assert len(endpoints) == 0 - - for tc_id in range(1, 6): - ep = create_endpoint_from_dft_data(tc_id=tc_id)[0] - db.add(ep) + for pattern in patterns: + assert pattern.call_count == 1 - ep = create_endpoint_from_dft_data(tc_id=tc_id + 10, branch_id=199)[0] - db.add(ep) + for webhook in webhooks: + assert webhook.status == 'Success' + assert webhook.status_code == 200 - ep = create_endpoint_from_dft_data(tc_id=tc_id + 100, branch_id=299)[0] - db.add(ep) + @respx.mock + def test_send_correct_branch(self, db: Session, client: TestClient, celery_session_worker): + # Create endpoints for different branches + branch_endpoints = {} + for branch_id in [99, 199]: + eps = create_endpoint_from_dft_data(tc_id=branch_id, branch_id=branch_id) + branch_endpoints[branch_id] = eps + for ep in eps: + payload = get_dft_webhook_data(branch_id=branch_id) + headers = _get_webhook_headers() + headers['webhook-signature'] = hmac.new( + ep.api_key.encode(), json.dumps(payload).encode(), hashlib.sha256 + ).hexdigest() + respx.post(ep.webhook_url).mock( + return_value=get_successful_response(payload, headers, status='success') + ) + db.add(ep) db.commit() - endpoints = db.exec(select(WebhookEndpoint)).all() - assert len(endpoints) == 15 - - endpoints_1 = db.exec(select(WebhookEndpoint).where(WebhookEndpoint.branch_id == 99)).all() - assert len(endpoints_1) == 5 - endpoints_2 = db.exec(select(WebhookEndpoint).where(WebhookEndpoint.branch_id == 199)).all() - assert len(endpoints_2) == 5 - endpoints_3 = db.exec(select(WebhookEndpoint).where(WebhookEndpoint.branch_id == 299)).all() - assert len(endpoints_3) == 5 - - payload = get_dft_webhook_data() - headers = _get_webhook_headers() - mock_request = respx.post(endpoints[0].webhook_url).mock(return_value=get_successful_response(payload, headers)) - - webhooks = db.exec(select(WebhookLog)).all() - assert len(webhooks) == 0 + # Send webhooks for each branch + for branch_id in [99, 199]: + payload = get_dft_webhook_data(branch_id=branch_id) + sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) + sending_webhooks.get(timeout=10) - sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) - sending_webhooks.get(timeout=10) + # Verify webhooks were sent correctly webhooks = db.exec(select(WebhookLog)).all() - assert len(webhooks) == 5 - assert len(mock_request.calls) == 5 - - webhooks = db.exec( - select(WebhookLog).where(col(WebhookLog.webhook_endpoint_id).in_([ep.id for ep in endpoints_1])) - ).all() - assert len(webhooks) == 5 + assert len(webhooks) == len(branch_endpoints[99]) + len(branch_endpoints[199]) - webhooks = db.exec( - select(WebhookLog).where(col(WebhookLog.webhook_endpoint_id).in_([ep.id for ep in endpoints_2])) - ).all() - assert len(webhooks) == 0 - - webhooks = db.exec( - select(WebhookLog).where(col(WebhookLog.webhook_endpoint_id).in_([ep.id for ep in endpoints_3])) - ).all() - assert len(webhooks) == 0 - - payload = get_dft_webhook_data(branch_id=199) - headers = _get_webhook_headers() - mock_request.return_value = get_successful_response(payload, headers) - - sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) - sending_webhooks.get(timeout=10) - webhooks = db.exec(select(WebhookLog)).all() - assert len(webhooks) == 10 - assert len(mock_request.calls) == 10 - - webhooks = db.exec( - select(WebhookLog).where(col(WebhookLog.webhook_endpoint_id).in_([ep.id for ep in endpoints_1])) - ).all() - assert len(webhooks) == 5 - - webhooks = db.exec( - select(WebhookLog).where(col(WebhookLog.webhook_endpoint_id).in_([ep.id for ep in endpoints_2])) - ).all() - assert len(webhooks) == 5 - - webhooks = db.exec( - select(WebhookLog).where(col(WebhookLog.webhook_endpoint_id).in_([ep.id for ep in endpoints_3])) - ).all() - assert len(webhooks) == 0 + for webhook in webhooks: + assert webhook.status == 'Success' + assert webhook.status_code == 200 @respx.mock - def test_send_webhook_fail_to_send_only_one(self, db: Session, client: TestClient, celery_session_worker): + async def test_send_webhook_fail_to_send_only_one(self, db: Session, client: TestClient, celery_session_worker): eps = create_endpoint_from_dft_data() payload = get_dft_webhook_data() headers = _get_webhook_headers() @@ -242,45 +219,32 @@ def test_webhook_not_send_if_not_url(self, mock_logger, db: Session, client: Tes @patch('chronos.worker.app_logger') @respx.mock def test_webhook_not_send_errors(self, mock_logger, db: Session, client: TestClient, celery_session_worker): + eps = create_endpoint_from_dft_data() payload = get_dft_webhook_data() - eps = create_endpoint_from_dft_data(webhook_url='https://test-http-errors.com') + headers = _get_webhook_headers() for ep in eps: + pattern = respx.post(ep.webhook_url).mock( + return_value=httpx.Response( + status_code=500, + json={'status': 'error', 'message': 'Internal Server Error'}, + request=httpx.Request('POST', 'https://example.com', json=payload, headers=headers), + headers=headers, + ) + ) db.add(ep) db.commit() webhooks = db.exec(select(WebhookLog)).all() assert len(webhooks) == 0 - assert not mock_logger.info.called - mock_request = respx.post(ep.webhook_url).mock(side_effect=httpx.TimeoutException(message='Timeout error')) task_send_webhooks(json.dumps(payload)) webhooks = db.exec(select(WebhookLog)).all() - assert mock_logger.info.call_count == 4 - assert 'Timeout error sending webhook to' in mock_logger.info.call_args_list[1][0][0] - assert mock_request.call_count == 1 assert len(webhooks) == 1 + assert pattern.call_count == 1 webhook = webhooks[0] assert webhook.status == 'Unexpected response' - assert webhook.status_code == 999 - assert webhook.response_body == '{"Message": "No response from endpoint"}' - assert webhook.response_headers == '{"Message": "No response from endpoint"}' - - mock_request = respx.post(eps[0].webhook_url).mock(side_effect=httpx.RequestError(message='Connection error')) - task_send_webhooks(json.dumps(payload)) - webhooks = db.exec(select(WebhookLog)).all() - assert mock_logger.info.call_count == 8 - assert 'HTTP error sending webhook to' in mock_logger.info.call_args_list[5][0][0] - assert mock_request.call_count == 2 - assert len(webhooks) == 2 - - mock_request = respx.post(eps[0].webhook_url).mock(side_effect=httpx.HTTPError(message='HTTP error')) - task_send_webhooks(json.dumps(payload)) - webhooks = db.exec(select(WebhookLog)).all() - assert mock_logger.info.call_count == 12 - assert 'HTTP error sending webhook to' in mock_logger.info.call_args_list[9][0][0] - assert mock_request.call_count == 3 - assert len(webhooks) == 3 + assert webhook.status_code == 500 def test_delete_old_logs(self, db: Session, client: TestClient, celery_session_worker): eps = create_endpoint_from_dft_data() @@ -336,3 +300,359 @@ def test_delete_old_logs(self, db: Session, client: TestClient, celery_session_w # logs = db.exec(select(WebhookLog)).all() # # The log from 15 days ago is seconds older than the check and thus sdoesn't get deleted # assert len(logs) == 15 + + @patch('chronos.worker.app_logger') + def test_send_webhook_malformed_json(self, mock_logger, db: Session, client: TestClient, celery_session_worker): + """Test handling of malformed JSON payloads in webhook sending.""" + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + # Test with invalid JSON string + invalid_json = "{'invalid': 'json'}" # Using single quotes which is invalid JSON + with pytest.raises(Exception) as exc_info: + sending_webhooks = task_send_webhooks.delay(invalid_json) + sending_webhooks.get(timeout=10) + + assert 'JSONDecodeError' in str(exc_info.value) + + @patch('chronos.worker.app_logger') + def test_send_webhook_empty_payload(self, mock_logger, db: Session, client: TestClient, celery_session_worker): + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + payload = {'request_time': 1234567890, 'events': [{'branch': 99, 'event': 'test_event', 'data': {}}]} + + pattern = respx.post(ep.webhook_url).mock( + return_value=httpx.Response(999, json={'status': 'error'}, request=httpx.Request('POST', ep.webhook_url)) + ) + + sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) + sending_webhooks.get(timeout=10) + + webhooks = db.exec(select(WebhookLog)).all() + assert len(webhooks) == 1 + assert webhooks[0].status_code == 999 + assert pattern.call_count == 0 + + @respx.mock + def test_send_webhook_large_payload(self, db: Session, client: TestClient, celery_session_worker): + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + payload = get_dft_webhook_data() + payload['events'][0]['data'] = {'large_data': 'x' * 1000000} + headers = _get_webhook_headers() + headers['webhook-signature'] = hmac.new( + ep.api_key.encode(), json.dumps(payload).encode(), hashlib.sha256 + ).hexdigest() + + # Set up mock response for async request with pattern matching + pattern = respx.post(ep.webhook_url).mock(return_value=get_successful_response(payload, headers)) + + # Test without extension + sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) + sending_webhooks.get(timeout=10) + + webhooks = db.exec(select(WebhookLog)).all() + assert len(webhooks) == 1 + assert pattern.call_count == 1 + + webhook = webhooks[0] + assert webhook.status == 'Success' + assert webhook.status_code == 200 + assert webhook.webhook_endpoint_id == ep.id + + @respx.mock + def test_concurrent_webhook_requests(self, db: Session, client: TestClient, celery_session_worker): + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + payload = get_dft_webhook_data() + headers = _get_webhook_headers() + + pattern = respx.post(ep.webhook_url).mock( + return_value=get_successful_response(payload, headers, status='success') + ) + + # Send multiple concurrent requests + tasks = [] + for _ in range(5): + task = task_send_webhooks.delay(json.dumps(payload)) + tasks.append(task) + + # Wait for all tasks to complete + for task in tasks: + task.get(timeout=10) + + webhooks = db.exec(select(WebhookLog)).all() + assert len(webhooks) == 5 + assert pattern.call_count == 5 # Each task should make its own request + + @respx.mock + def test_queue_length_warning(self, db: Session, client: TestClient, celery_session_worker): + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + payload = get_dft_webhook_data() + headers = _get_webhook_headers() + + pattern = respx.post(ep.webhook_url).mock( + return_value=get_successful_response(payload, headers, status='success') + ) + + # Mock the queue length to be high + with patch('chronos.worker.get_qlength', return_value=150): + sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) + sending_webhooks.get(timeout=10) + + webhooks = db.exec(select(WebhookLog)).all() + assert len(webhooks) == 1 + assert webhooks[0].status_code == 200 + assert pattern.call_count == 1 + + @patch('chronos.worker.cache') + async def test_delete_old_logs_job_scheduler(self, mock_cache): + """Test the delete_old_logs_job scheduler function.""" + # Test when job is already running + mock_cache.get.return_value = 'True' + await delete_old_logs_job() + mock_cache.get.assert_called_once_with('delete_old_logs_job') + mock_cache.set.assert_not_called() + + # Reset mock + mock_cache.reset_mock() + + # Test when job is not running + mock_cache.get.return_value = None + await delete_old_logs_job() + mock_cache.get.assert_called_once_with('delete_old_logs_job') + mock_cache.set.assert_called_once_with('delete_old_logs_job', 'True', ex=1200) + + def test_get_count(self, db: Session): + """Test the get_count function for log deletion.""" + # Create some test logs with different timestamps + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + now = datetime.now(UTC) + + # Create 5 old logs (older than 5 days) + for i in range(5): + log = WebhookLog( + webhook_endpoint_id=ep.id, + timestamp=now - timedelta(days=6), + request_headers='{}', + request_body='{}', + response_headers='{}', + response_body='{}', + status='Success', + status_code=200, + ) + db.add(log) + + # Create 5 recent logs (within 5 days) + for i in range(5): + log = WebhookLog( + webhook_endpoint_id=ep.id, + timestamp=now - timedelta(days=2), + request_headers='{}', + request_body='{}', + response_headers='{}', + response_body='{}', + status='Success', + status_code=200, + ) + db.add(log) + + db.commit() + + # Check count of logs older than 5 days + assert get_count(now - timedelta(days=5)) == 5 + + @patch('chronos.worker.gc.collect') + def test_delete_old_logs_batch_processing(self, mock_gc, db: Session): + """Test batch processing in delete_old_logs_job.""" + # Create more logs than the delete limit + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + # Create 6000 logs (more than the 4999 delete limit) + now = datetime.now(UTC) + for i in range(6000): + log = create_webhook_log_from_dft_data( + webhook_endpoint_id=ep.id, + timestamp=now - timedelta(days=16), # All logs are older than 15 days + ) + db.add(log) + db.commit() + + # Run the deletion job + _delete_old_logs_job() + + # Verify all logs were deleted + remaining_logs = db.exec(select(WebhookLog)).all() + assert len(remaining_logs) == 0 + + # Verify garbage collection was called + assert mock_gc.call_count > 0 + + @respx.mock + def test_async_post_webhooks_invalid_url(self, db: Session): + """Test handling of webhooks with invalid URLs.""" + # Create endpoint with invalid URL + ep = create_endpoint_from_dft_data(webhook_url='invalid://url')[0] + db.add(ep) + db.commit() + + payload = get_dft_webhook_data() + webhook_logs, total_success, total_failed = asyncio.run(_async_post_webhooks([ep], None, json.dumps(payload))) + + assert len(webhook_logs) == 0 + assert total_success == 0 + assert total_failed == 0 + + @respx.mock + def test_async_post_webhooks_connection_error(self, db: Session): + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + payload = get_dft_webhook_data() + + # Set up mock response for async request with pattern matching + pattern = respx.post(ep.webhook_url).mock(side_effect=httpx.TimeoutException('Timeout')) + + webhook_logs, total_success, total_failed = asyncio.run(_async_post_webhooks([ep], None, json.dumps(payload))) + + assert len(webhook_logs) == 1 + assert pattern.call_count == 1 + assert total_success == 0 + assert total_failed == 1 + + webhook = webhook_logs[0] + assert webhook.status == 'Unexpected response' + assert webhook.status_code == 999 + + @respx.mock + async def test_async_post_webhooks_mixed_responses(self, db: Session): + # Create two endpoints + ep1 = create_endpoint_from_dft_data(tc_id=1)[0] + ep2 = create_endpoint_from_dft_data(tc_id=2)[0] + db.add(ep1) + db.add(ep2) + db.commit() + + payload = get_dft_webhook_data() + + # Set up mock responses - one success, one failure + pattern1 = respx.post(ep1.webhook_url).mock( + return_value=get_successful_response(payload, _get_webhook_headers()) + ) + pattern2 = respx.post(ep2.webhook_url).mock(return_value=get_failed_response(payload, _get_webhook_headers())) + + webhook_logs, total_success, total_failed = await _async_post_webhooks([ep1, ep2], None, json.dumps(payload)) + + assert len(webhook_logs) == 2 + assert pattern1.call_count == 1 + assert pattern2.call_count == 1 + assert total_success == 1 + assert total_failed == 1 + + # Check the successful webhook + success_webhook = next(log for log in webhook_logs if log.status == 'Success') + assert success_webhook.status_code == 200 + + # Check the failed webhook + failed_webhook = next(log for log in webhook_logs if log.status == 'Unexpected response') + assert failed_webhook.status_code == 409 + + @respx.mock + def test_webhook_request_direct(self, db: Session): + ep = create_endpoint_from_dft_data()[0] + db.add(ep) + db.commit() + + async def test_timeout(): + pattern = respx.post(ep.webhook_url).mock(side_effect=httpx.TimeoutException('Timeout')) + client = httpx.AsyncClient(timeout=0.1) + try: + with pytest.raises(httpx.TimeoutException): + await webhook_request( + client, ep.webhook_url, ep.id, webhook_sig='test', data=get_dft_webhook_data() + ) + finally: + await client.aclose() + assert pattern.call_count == 1 + + asyncio.run(test_timeout()) + + @respx.mock + async def test_webhook_url_validation(self, db: Session, client: TestClient, celery_session_worker): + # Create endpoints with different URL schemes + valid_ep = create_endpoint_from_dft_data(tc_id=1, webhook_url='https://test_endpoint_1.com')[0] + invalid_ep = create_endpoint_from_dft_data(tc_id=2, webhook_url='invalid://test_endpoint_5.com')[0] + db.add(valid_ep) + db.add(invalid_ep) + db.commit() + + payload = get_dft_webhook_data() + + # Set up mock response for valid endpoint + pattern = respx.post(valid_ep.webhook_url).mock( + return_value=get_successful_response(payload, _get_webhook_headers()) + ) + + # Test sending webhooks + sending_webhooks = task_send_webhooks.delay(json.dumps(payload)) + sending_webhooks.get(timeout=10) + + webhooks = db.exec(select(WebhookLog)).all() + assert len(webhooks) == 1 + assert pattern.call_count == 1 + + webhook = webhooks[0] + assert webhook.status == 'Success' + assert webhook.status_code == 200 + assert webhook.webhook_endpoint_id == valid_ep.id + + @respx.mock + def test_connection_pool_limits(self, db: Session): + # Clean up any existing endpoints + db.exec(delete(WebhookEndpoint)) + db.commit() + + # Create more endpoints than the connection limit + num_endpoints = 300 # More than the 250 connection limit + endpoints = [] + patterns = [] + payload = get_dft_webhook_data() + headers = _get_webhook_headers() + + for i in range(num_endpoints): + ep = create_endpoint_from_dft_data(tc_id=i + 1)[0] + ep.webhook_url = f'https://test_endpoint_{i + 1}.com' + endpoints.append(ep) + db.add(ep) + + # Set up mock response + pattern = respx.post(ep.webhook_url).mock(return_value=get_successful_response(payload, headers)) + patterns.append(pattern) + db.commit() + + # Send webhooks + webhook_logs, total_success, total_failed = asyncio.run( + _async_post_webhooks(endpoints, None, json.dumps(payload)) + ) + + # Verify all webhooks were processed + assert len(webhook_logs) == num_endpoints + assert total_success == num_endpoints + assert total_failed == 0 + assert all(pattern.call_count == 1 for pattern in patterns)