diff --git a/model-engine/example_sleep_model_deployment/service_config.yaml b/model-engine/example_sleep_model_deployment/service_config.yaml new file mode 100644 index 000000000..97c530e75 --- /dev/null +++ b/model-engine/example_sleep_model_deployment/service_config.yaml @@ -0,0 +1,27 @@ +bundle_config: + model_bundle_name: sleep-model-timeout-test + request_schema: Dict[str, Any] + response_schema: Dict[str, Any] + repository: sleep_model + tag: latest + command: + - python + - app.py + readiness_initial_delay_seconds: 30 + +endpoint_config: + endpoint_name: sleep-model-timeout-test + model_bundle: sleep-model-timeout-test + cpus: 1 + memory: 2Gi + storage: 10Gi + gpus: 0 + min_workers: 1 + max_workers: 1 + per_worker: 1 + endpoint_type: async + queue_message_timeout_duration: 90 # 90 seconds to handle 70s inference + buffer + labels: + team: test + product: sleep-model-timeout-test + update_if_exists: True diff --git a/model-engine/example_sleep_model_deployment/sleep_model/Dockerfile b/model-engine/example_sleep_model_deployment/sleep_model/Dockerfile new file mode 100644 index 000000000..3e043d68b --- /dev/null +++ b/model-engine/example_sleep_model_deployment/sleep_model/Dockerfile @@ -0,0 +1,16 @@ +FROM python:3.9-slim + +WORKDIR /app + +# Copy requirements and install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY app.py . + +# Expose port +EXPOSE 8080 + +# Run the application +CMD ["gunicorn", "--bind", "0.0.0.0:8080", "--timeout", "120", "app:app"] diff --git a/model-engine/example_sleep_model_deployment/sleep_model/app.py b/model-engine/example_sleep_model_deployment/sleep_model/app.py new file mode 100644 index 000000000..ef843f1ca --- /dev/null +++ b/model-engine/example_sleep_model_deployment/sleep_model/app.py @@ -0,0 +1,50 @@ +""" +Simple sleep model for testing queue timeout duration. +This model sleeps for 70 seconds to test queue lock duration > 60 seconds. +""" + +import time +from typing import Any, Dict +from flask import Flask, request, jsonify + +app = Flask(__name__) + +@app.route('/predict', methods=['POST']) +def predict(): + """ + Prediction endpoint that sleeps for 70 seconds to test queue timeout. + """ + try: + data = request.get_json() + sleep_duration = data.get('sleep_duration', 70) # Default 70 seconds + + print(f"Starting inference... will sleep for {sleep_duration} seconds") + + # Sleep to simulate long-running inference + time.sleep(sleep_duration) + + response = { + "result": f"Completed after sleeping for {sleep_duration} seconds", + "input": data, + "status": "success" + } + + print(f"Inference completed successfully after {sleep_duration} seconds") + return jsonify(response) + + except Exception as e: + print(f"Error during inference: {e}") + return jsonify({"error": str(e), "status": "failed"}), 500 + +@app.route('/health', methods=['GET']) +def health(): + """Health check endpoint""" + return jsonify({"status": "healthy"}) + +@app.route('/readyz', methods=['GET']) +def ready(): + """Readiness check endpoint""" + return jsonify({"status": "ready"}) + +if __name__ == '__main__': + app.run(host='0.0.0.0', port=8080) diff --git a/model-engine/example_sleep_model_deployment/sleep_model/requirements.txt b/model-engine/example_sleep_model_deployment/sleep_model/requirements.txt new file mode 100644 index 000000000..7eb4708cc --- /dev/null +++ b/model-engine/example_sleep_model_deployment/sleep_model/requirements.txt @@ -0,0 +1,2 @@ +Flask==2.3.3 +gunicorn==21.2.0 diff --git a/model-engine/model_engine_server/common/dtos/endpoint_builder.py b/model-engine/model_engine_server/common/dtos/endpoint_builder.py index bdad051db..2a3db8fe4 100644 --- a/model-engine/model_engine_server/common/dtos/endpoint_builder.py +++ b/model-engine/model_engine_server/common/dtos/endpoint_builder.py @@ -35,6 +35,7 @@ class BuildEndpointRequest(BaseModel): high_priority: Optional[bool] = None default_callback_url: Optional[str] = None default_callback_auth: Optional[CallbackAuth] = None + queue_message_timeout_duration: Optional[int] = None class BuildEndpointStatus(str, Enum): diff --git a/model-engine/model_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py index 36a7c7f68..ff2d9f3c7 100644 --- a/model-engine/model_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -73,6 +73,7 @@ class CreateModelEndpointV1Request(BaseModel): default_callback_url: Optional[HttpUrlStr] = None default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = Field(default=False) + queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1) class CreateModelEndpointV1Response(BaseModel): @@ -100,6 +101,7 @@ class UpdateModelEndpointV1Request(BaseModel): default_callback_url: Optional[HttpUrlStr] = None default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = None + queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1) class UpdateModelEndpointV1Response(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index aa2f4e9d2..f5bcd6820 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1405,6 +1405,7 @@ async def execute( default_callback_url=request.default_callback_url, default_callback_auth=request.default_callback_auth, public_inference=request.public_inference, + queue_message_timeout_duration=request.queue_message_timeout_duration, ) _handle_post_inference_hooks( created_by=user.user_id, diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index ea8466430..ddd56923b 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -389,6 +389,7 @@ async def execute( default_callback_url=request.default_callback_url, default_callback_auth=request.default_callback_auth, public_inference=request.public_inference, + queue_message_timeout_duration=request.queue_message_timeout_duration, ) _handle_post_inference_hooks( created_by=user.user_id, diff --git a/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py index bca30e10a..13ca40a98 100644 --- a/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py @@ -76,6 +76,7 @@ def create_model_endpoint_infra( billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], + queue_message_timeout_duration: Optional[int] = None, ) -> str: deployment_name = generate_deployment_name( model_endpoint_record.created_by, model_endpoint_record.name @@ -104,6 +105,7 @@ def create_model_endpoint_infra( billing_tags=billing_tags, default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, + queue_message_timeout_duration=queue_message_timeout_duration, ) response = self.task_queue_gateway.send_task( task_name=BUILD_TASK_NAME, diff --git a/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py index 3799ed654..099e25829 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py @@ -1,9 +1,10 @@ import os -from typing import Any, Dict +from datetime import timedelta +from typing import Any, Dict, Optional from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from azure.identity import DefaultAzureCredential -from azure.servicebus.management import ServiceBusAdministrationClient +from azure.servicebus.management import ServiceBusAdministrationClient, QueueProperties from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import EndpointResourceInfraException from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( @@ -32,13 +33,36 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + timeout_duration = queue_message_timeout_duration or 60 # Default to 60 seconds + + # Validation: Azure Service Bus lock duration must be <= 5 minutes (300s) + if timeout_duration > 300: + raise ValueError(f"queue_message_timeout_duration ({timeout_duration}s) exceeds Azure Service Bus maximum of 300 seconds") + with _get_servicebus_administration_client() as client: try: + # First, try to create the queue with default properties client.create_queue(queue_name=queue_name) + + # Then update the queue properties to set custom lock duration + queue_properties = client.get_queue(queue_name) + queue_properties.lock_duration = timedelta(seconds=timeout_duration) + client.update_queue(queue_properties) + except ResourceExistsError: - pass + # Queue already exists, update its properties if needed + try: + queue_properties = client.get_queue(queue_name) + # Only update if the lock duration is different + if queue_properties.lock_duration != timedelta(seconds=timeout_duration): + queue_properties.lock_duration = timedelta(seconds=timeout_duration) + client.update_queue(queue_properties) + except Exception as e: + # If we can't update properties, log but don't fail + logger.warning(f"Could not update queue properties for {queue_name}: {e}") return QueueInfo(queue_name, None) diff --git a/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py index 9ded2d6e5..b43e5c4cc 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Sequence +from typing import Any, Dict, Optional, Sequence from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( QueueEndpointResourceDelegate, @@ -15,6 +15,7 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) queue_url = f"http://foobar.com/{queue_name}" diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index 4e7759747..86c56602b 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -38,6 +38,7 @@ async def create_queue( self, endpoint_record: ModelEndpointRecord, labels: Dict[str, str], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: """Creates a new queue, returning its unique name and queue URL.""" queue_name, queue_url = await self.queue_delegate.create_queue_if_not_exists( @@ -45,6 +46,7 @@ async def create_queue( endpoint_name=endpoint_record.name, endpoint_created_by=endpoint_record.created_by, endpoint_labels=labels, + queue_message_timeout_duration=queue_message_timeout_duration, ) return QueueInfo(queue_name, queue_url) @@ -56,7 +58,11 @@ async def create_or_update_resources( request.build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.ASYNC ): - q = await self.create_queue(endpoint_record, request.build_endpoint_request.labels) + q = await self.create_queue( + endpoint_record, + request.build_endpoint_request.labels, + request.build_endpoint_request.queue_message_timeout_duration + ) queue_name: Optional[str] = q.queue_name queue_url: Optional[str] = q.queue_url else: diff --git a/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py index 76c77e64b..cbfb8d3bb 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py @@ -24,6 +24,7 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: """ Creates a queue associated with the given endpoint_id. Other fields are set as tags on the queue. diff --git a/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py index 748c3f699..d8df66635 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py @@ -55,7 +55,10 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: + timeout_duration = queue_message_timeout_duration or 60 # Default to 60 seconds + async with _create_async_sqs_client(sqs_profile=self.sqs_profile) as sqs_client: queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) @@ -73,9 +76,7 @@ async def create_queue_if_not_exists( create_response = await sqs_client.create_queue( QueueName=queue_name, Attributes=dict( - VisibilityTimeout="43200", - # To match current hardcoded Celery timeout of 24hr - # However, the max SQS visibility is 12hrs. + VisibilityTimeout=str(timeout_duration), Policy=_get_queue_policy(queue_name=queue_name), ), tags=_get_queue_tags( diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_service.py index 413d05aae..a08168d78 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_service.py @@ -127,6 +127,7 @@ async def create_batch_job( owner=owner, default_callback_url=None, default_callback_auth=None, + queue_message_timeout_duration=None, ) await self.batch_job_record_repository.update_batch_job_record( diff --git a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 1b32104a5..d355336a0 100644 --- a/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -504,7 +504,7 @@ def get_base_image_params( return BuildImageRequest( repo=hmi_config.user_inference_base_repository, image_tag=resulting_image_tag[:MAX_IMAGE_TAG_LEN], - aws_profile=ECR_AWS_PROFILE, # type: ignore + aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, @@ -577,7 +577,7 @@ def _get_user_image_params( return BuildImageRequest( repo=ecr_repo, image_tag=service_image_tag[:MAX_IMAGE_TAG_LEN], - aws_profile=ECR_AWS_PROFILE, + aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, @@ -633,7 +633,7 @@ def _get_inject_bundle_image_params( return BuildImageRequest( repo=ecr_repo, image_tag=service_image_tag[:MAX_IMAGE_TAG_LEN], - aws_profile=ECR_AWS_PROFILE, + aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, @@ -667,7 +667,7 @@ async def _build_image( if not self.docker_repository.image_exists( repository_name=image_params.repo, image_tag=image_params.image_tag, - aws_profile=ECR_AWS_PROFILE, + aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE ): self.monitoring_metrics_gateway.emit_image_build_cache_miss_metric(image_type) tags = [ @@ -713,7 +713,7 @@ async def _build_image( if not build_result_status and not self.docker_repository.image_exists( repository_name=image_params.repo, image_tag=image_params.image_tag, - aws_profile=ECR_AWS_PROFILE, + aws_profile=build_endpoint_request.aws_role, # Use user-provided aws_role instead of hardcoded ECR_AWS_PROFILE ): log_error( f"Image build failed for endpoint {model_endpoint_name}, user {user_id}" diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index 6c28f4990..b66e4ffd9 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -164,6 +164,7 @@ async def create_model_endpoint( default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth], public_inference: Optional[bool] = False, + queue_message_timeout_duration: Optional[int] = None, ) -> ModelEndpointRecord: existing_endpoints = ( await self.model_endpoint_record_repository.list_model_endpoint_records( @@ -209,6 +210,7 @@ async def create_model_endpoint( high_priority=high_priority, default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, + queue_message_timeout_duration=queue_message_timeout_duration, ) await self.model_endpoint_record_repository.update_model_endpoint_record( model_endpoint_id=model_endpoint_record.id, diff --git a/model-engine/tests/unit/infra/gateways/resources/test_asb_queue_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_asb_queue_endpoint_resource_delegate.py new file mode 100644 index 000000000..99307e4c7 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/resources/test_asb_queue_endpoint_resource_delegate.py @@ -0,0 +1,93 @@ +import pytest +from unittest.mock import patch, MagicMock +from datetime import timedelta + +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, +) + + +class TestASBQueueEndpointResourceDelegate: + @pytest.fixture + def delegate(self): + return ASBQueueEndpointResourceDelegate() + + @pytest.fixture + def mock_servicebus_client(self): + with patch( + "model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate._get_servicebus_administration_client" + ) as mock_client: + yield mock_client + + @pytest.mark.asyncio + async def test_create_queue_with_default_timeout(self, delegate, mock_servicebus_client): + """Test queue creation with default 60-second timeout""" + mock_client = MagicMock() + mock_servicebus_client.return_value.__enter__.return_value = mock_client + + result = await delegate.create_queue_if_not_exists( + endpoint_id="test-endpoint", + endpoint_name="test-endpoint", + endpoint_created_by="test-user", + endpoint_labels={"team": "test"}, + queue_message_timeout_duration=60, # Default + ) + + # Verify queue creation was called with correct properties + mock_client.create_queue.assert_called_once() + args, kwargs = mock_client.create_queue.call_args + + assert "queue_properties" in kwargs + queue_properties = kwargs["queue_properties"] + assert queue_properties.lock_duration == timedelta(seconds=60) + assert result.queue_name == "launch-endpoint-id-test-endpoint" + + @pytest.mark.asyncio + async def test_create_queue_with_custom_timeout(self, delegate, mock_servicebus_client): + """Test queue creation with custom timeout duration""" + mock_client = MagicMock() + mock_servicebus_client.return_value.__enter__.return_value = mock_client + + await delegate.create_queue_if_not_exists( + endpoint_id="test-endpoint", + endpoint_name="test-endpoint", + endpoint_created_by="test-user", + endpoint_labels={"team": "test"}, + queue_message_timeout_duration=180, # 3 minutes + ) + + # Verify queue creation was called with custom timeout + args, kwargs = mock_client.create_queue.call_args + queue_properties = kwargs["queue_properties"] + assert queue_properties.lock_duration == timedelta(seconds=180) + + @pytest.mark.asyncio + async def test_create_queue_timeout_validation_error(self, delegate, mock_servicebus_client): + """Test that timeout > 300 seconds raises ValidationError""" + with pytest.raises(ValueError, match="exceeds Azure Service Bus maximum of 300 seconds"): + await delegate.create_queue_if_not_exists( + endpoint_id="test-endpoint", + endpoint_name="test-endpoint", + endpoint_created_by="test-user", + endpoint_labels={"team": "test"}, + queue_message_timeout_duration=400, # > 300 seconds + ) + + @pytest.mark.asyncio + async def test_create_queue_max_allowed_timeout(self, delegate, mock_servicebus_client): + """Test queue creation with maximum allowed timeout (300s)""" + mock_client = MagicMock() + mock_servicebus_client.return_value.__enter__.return_value = mock_client + + await delegate.create_queue_if_not_exists( + endpoint_id="test-endpoint", + endpoint_name="test-endpoint", + endpoint_created_by="test-user", + endpoint_labels={"team": "test"}, + queue_message_timeout_duration=300, # Exactly 5 minutes + ) + + # Should succeed without error + args, kwargs = mock_client.create_queue.call_args + queue_properties = kwargs["queue_properties"] + assert queue_properties.lock_duration == timedelta(seconds=300) diff --git a/model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py index ae00ac439..278c548b9 100644 --- a/model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py @@ -335,6 +335,58 @@ def mock_create_async_sqs_client_get_queue_attributes_queue_throws_exception(): yield mock_create_async_sqs_client +@pytest.mark.asyncio +async def test_sqs_create_queue_with_custom_timeout_duration( + build_endpoint_request_async_custom: BuildEndpointRequest, + mock_create_async_sqs_client_create_queue, +): + """Test SQS queue creation with custom timeout duration""" + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") + endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record + + await delegate.create_queue_if_not_exists( + endpoint_id=endpoint_record.id, + endpoint_name=endpoint_record.name, + endpoint_created_by=endpoint_record.created_by, + endpoint_labels=build_endpoint_request_async_custom.labels, + queue_message_timeout_duration=180, # 3 minutes + ) + + # Verify that create_queue was called with the custom timeout + mock_client = mock_create_async_sqs_client_create_queue.__aenter__.return_value + mock_client.create_queue.assert_called_once() + args, kwargs = mock_client.create_queue.call_args + + # Check that VisibilityTimeout was set to our custom value + assert kwargs["Attributes"]["VisibilityTimeout"] == "180" + + +@pytest.mark.asyncio +async def test_sqs_create_queue_with_default_timeout_duration( + build_endpoint_request_async_custom: BuildEndpointRequest, + mock_create_async_sqs_client_create_queue, +): + """Test SQS queue creation with default timeout duration""" + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") + endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record + + await delegate.create_queue_if_not_exists( + endpoint_id=endpoint_record.id, + endpoint_name=endpoint_record.name, + endpoint_created_by=endpoint_record.created_by, + endpoint_labels=build_endpoint_request_async_custom.labels, + queue_message_timeout_duration=60, # Default + ) + + # Verify that create_queue was called with the default timeout + mock_client = mock_create_async_sqs_client_create_queue.__aenter__.return_value + mock_client.create_queue.assert_called_once() + args, kwargs = mock_client.create_queue.call_args + + # Check that VisibilityTimeout was set to default value + assert kwargs["Attributes"]["VisibilityTimeout"] == "60" + + @pytest.mark.asyncio async def test_sqs_create_or_update_resources_endpoint_exists( build_endpoint_request_async_custom: BuildEndpointRequest, diff --git a/test_azure_queue_timeout.py b/test_azure_queue_timeout.py new file mode 100644 index 000000000..f6d3921e6 --- /dev/null +++ b/test_azure_queue_timeout.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Integration test script for Azure Service Bus queue timeout duration. +This script tests the actual Azure Service Bus queue creation with custom lock duration. + +Usage: + export SERVICEBUS_NAMESPACE="your-namespace" + export AZURE_TENANT_ID="your-tenant-id" + export AZURE_CLIENT_ID="your-client-id" + export AZURE_CLIENT_SECRET="your-client-secret" + python test_azure_queue_timeout.py +""" + +import asyncio +import os +import sys +from datetime import timedelta +from azure.servicebus.management import ServiceBusAdministrationClient, QueueProperties +from azure.identity import DefaultAzureCredential + +async def test_queue_creation_with_timeout(): + """Test creating Azure Service Bus queue with custom lock duration""" + + # Check required environment variables + namespace = os.getenv("SERVICEBUS_NAMESPACE") + if not namespace: + print("ERROR: SERVICEBUS_NAMESPACE environment variable not set") + return False + + try: + # Create Service Bus Administration Client + client = ServiceBusAdministrationClient( + f"{namespace}.servicebus.windows.net", + credential=DefaultAzureCredential() + ) + + test_cases = [ + {"timeout": 60, "description": "default timeout (60s)"}, + {"timeout": 120, "description": "custom timeout (120s)"}, + {"timeout": 300, "description": "maximum timeout (300s)"}, + ] + + for i, test_case in enumerate(test_cases): + queue_name = f"test-queue-timeout-{i}-{test_case['timeout']}s" + timeout_seconds = test_case["timeout"] + + print(f"\n--- Testing {test_case['description']} ---") + + try: + # Create queue with custom lock duration + queue_properties = QueueProperties( + lock_duration=timedelta(seconds=timeout_seconds) + ) + + print(f"Creating queue: {queue_name}") + client.create_queue(queue_name, queue_properties=queue_properties) + print(f"✓ Queue created successfully") + + # Verify the queue properties + queue_props = client.get_queue(queue_name) + actual_lock_duration = queue_props.lock_duration.total_seconds() + + print(f"Expected lock duration: {timeout_seconds}s") + print(f"Actual lock duration: {actual_lock_duration}s") + + if actual_lock_duration == timeout_seconds: + print(f"✓ Lock duration matches expected value") + else: + print(f"✗ Lock duration mismatch!") + return False + + # Clean up - delete the test queue + client.delete_queue(queue_name) + print(f"✓ Test queue deleted") + + except Exception as e: + print(f"✗ Error testing {test_case['description']}: {e}") + return False + + # Test validation error for timeout > 300s + print(f"\n--- Testing validation error for timeout > 300s ---") + try: + queue_properties = QueueProperties( + lock_duration=timedelta(seconds=400) # Should fail + ) + client.create_queue("test-invalid-timeout", queue_properties=queue_properties) + print("✗ Should have failed for timeout > 300s") + return False + except Exception as e: + print(f"✓ Correctly rejected timeout > 300s: {e}") + + print(f"\n🎉 All tests passed!") + return True + + except Exception as e: + print(f"✗ Connection or authentication error: {e}") + print("Make sure you have the correct Azure credentials set up") + return False + +if __name__ == "__main__": + success = asyncio.run(test_queue_creation_with_timeout()) + sys.exit(0 if success else 1)