diff --git a/.secrets.baseline b/.secrets.baseline index 76fd341..11cd8a2 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -178,9 +178,9 @@ "filename": "tests/test_s3_endpoint.py", "hashed_secret": "08d2e98e6754af941484848930ccbaddfefe13d6", "is_verified": false, - "line_number": 77 + "line_number": 82 } ] }, - "generated_at": "2026-03-02T20:27:56Z" + "generated_at": "2026-03-03T22:13:17Z" } diff --git a/gen3workflow/app.py b/gen3workflow/app.py index 6d30c38..d31059c 100644 --- a/gen3workflow/app.py +++ b/gen3workflow/app.py @@ -66,7 +66,16 @@ async def get_status(): root_path=config["DOCS_URL_PREFIX"], generate_unique_id_function=generate_unique_route_id, ) - app.async_client = httpx_client or httpx.AsyncClient() + + # `async_client` is used to hit the TES API, the Arborist service and AWS S3. + # Calls to S3 tend to timeout when uploading large files (and we might also be rate-limited). + # AsyncHTTPTransport supports retrying on httpx.ConnectError or httpx.ConnectTimeout. + # The `httpx_client` parameter is not meant to be used in production. It allows mocking + # external calls when testing. + app.async_client = httpx_client or httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(retries=3), timeout=120 + ) + app.include_router(ga4gh_tes_router, tags=["GA4GH TES"]) app.include_router(s3_router, tags=["S3"]) app.include_router(storage_router, tags=["Storage"]) diff --git a/gen3workflow/routes/s3.py b/gen3workflow/routes/s3.py index c1db539..18820f7 100644 --- a/gen3workflow/routes/s3.py +++ b/gen3workflow/routes/s3.py @@ -192,10 +192,6 @@ async def s3_endpoint(path: str, request: Request): # get the name of the user's bucket and ensure the user is making a call to their own bucket logger.info(f"Incoming S3 request from user '{user_id}': '{request.method} {path}'") user_bucket = aws_utils.get_safe_name_from_hostname(user_id) - if request.method == "GET" and path == "s3": - err_msg = f"'ls' not supported, use 'ls s3://{user_bucket}' instead" - logger.error(err_msg) - raise HTTPException(HTTP_400_BAD_REQUEST, err_msg) request_bucket = path.split("?")[0].split("/")[0] if request_bucket != user_bucket: err_msg = f"'{path}' (bucket '{request_bucket}') not allowed. You can make calls to your personal bucket, '{user_bucket}'" @@ -282,8 +278,17 @@ async def s3_endpoint(path: str, request: Request): assert credentials, "No AWS credentials found" headers["x-amz-security-token"] = credentials.token - # if this is a PUT request, we need the KMS key ID to use for encryption - if config["KMS_ENCRYPTION_ENABLED"] and request.method == "PUT": + # If this is a PUT or POST request, specify the KMS key to use for encryption. + # For multipart uploads, the initial CreateMultipartUpload request includes the KMS + # configuration, and the following UploadPart and CompleteMultipartUpload requests do not. + # We know this is an UploadPart or CompleteMultipartUpload request if it includes the + # uploadId query parameter. + query_params = dict(request.query_params) + if ( + config["KMS_ENCRYPTION_ENABLED"] + and request.method in ["PUT", "POST"] + and "uploadId" not in query_params + ): _, kms_key_arn = aws_utils.get_existing_kms_key_for_bucket(user_bucket) if not kms_key_arn: err_msg = "Bucket misconfigured. Hit the `GET /storage/setup` endpoint and try again." @@ -300,7 +305,6 @@ async def s3_endpoint(path: str, request: Request): f"{key.lower()}:{headers[key]}\n" for key in sorted_headers ) signed_headers = ";".join([k.lower() for k in sorted_headers]) - query_params = dict(request.query_params) # the query params in the canonical request have to be sorted: query_params_names = sorted(list(query_params.keys())) canonical_query_params = "&".join( @@ -339,9 +343,6 @@ async def s3_endpoint(path: str, request: Request): ) s3_api_url = f"https://{user_bucket}.s3.{region}.amazonaws.com/{api_endpoint}" logger.debug(f"Outgoing S3 request: '{request.method} {s3_api_url}'") - - # TODO: Enclose this with a retry if S3 response with a 500 error (which is possible! Failing - # fast can break a whole nextflow workflow) response = await request.app.async_client.request( method=request.method, url=s3_api_url, diff --git a/tests/conftest.py b/tests/conftest.py index 721acf5..5a1124a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -226,10 +226,33 @@ def mock_tes_server_request_function( return httpx.Response(status_code=status_code, json=out, text=text) +def mock_aws_s3_request_function(url: str): + """ + Mock responses from AWS S3 + """ + resp_xml = MOCKED_S3_RESPONSE_XML + headers = {"content-type": "application/xml"} + + # multipart upload special case: + if "test_s3_upload_file_multipart.txt" in url: + upload_id = "test-upload-id" + # "InitiateMultipartUploadResult" with "UploadId" + resp_xml = f"""\ngen3wf-{config['HOSTNAME']}-{TEST_USER_ID}test_s3_chunk_upload.txt{upload_id}""" + if f"?uploadId={upload_id}&partNumber=" in url: + headers["etag"] = "test-etag" + + return httpx.Response( + status_code=200, + text=resp_xml, + headers=headers, + ) + + # making these functions into mocks allows tests to check the requests that were made, for # example: `mock_tes_server_request.assert_called_with(...)` mock_tes_server_request = MagicMock(side_effect=mock_tes_server_request_function) mock_arborist_request = MagicMock(side_effect=mock_arborist_request_function) +mock_aws_s3_request = MagicMock(side_effect=mock_aws_s3_request_function) @pytest_asyncio.fixture(scope="function", autouse=True) @@ -316,20 +339,7 @@ async def handle_request(request: Request): f"https://gen3wf-{config['HOSTNAME']}-{TEST_USER_ID}.s3.{config['USER_BUCKETS_REGION']}.amazonaws.com" ): # mock calls to AWS S3 - resp_xml = MOCKED_S3_RESPONSE_XML - headers = {"content-type": "application/xml"} - # multipart upload special case: - if "test_s3_upload_file_multipart.txt" in url: - upload_id = "test-upload-id" - # "InitiateMultipartUploadResult" with "UploadId" - resp_xml = f"""\ngen3wf-{config['HOSTNAME']}-{TEST_USER_ID}test_s3_chunk_upload.txt{upload_id}""" - if f"?uploadId={upload_id}&partNumber=" in url: - headers["etag"] = "test-etag" - mocked_response = httpx.Response( - status_code=200, - text=resp_xml, - headers=headers, - ) + mocked_response = mock_aws_s3_request(url) if mocked_response is not None: print(f"Mocking request '{request.method} {url}'") @@ -339,7 +349,7 @@ async def handle_request(request: Request): httpx_client_function = getattr(httpx.AsyncClient(), request.method.lower()) return await httpx_client_function(url) - # set the httpx clients used by the app and by the Arborist client to mock clients that + # the httpx clients used by the app and by gen3authz are set to mock clients that # call `handle_request` mock_httpx_client = httpx.AsyncClient(transport=httpx.MockTransport(handle_request)) app = get_app(httpx_client=mock_httpx_client) diff --git a/tests/test_s3_endpoint.py b/tests/test_s3_endpoint.py index 8964e05..da40c79 100644 --- a/tests/test_s3_endpoint.py +++ b/tests/test_s3_endpoint.py @@ -8,7 +8,12 @@ from fastapi import HTTPException import pytest -from conftest import MOCKED_S3_RESPONSE_DICT, TEST_USER_ID, TEST_USER_TOKEN +from conftest import ( + MOCKED_S3_RESPONSE_DICT, + TEST_USER_ID, + TEST_USER_TOKEN, + mock_aws_s3_request, +) from gen3workflow.config import config from gen3workflow.routes.s3 import ( set_access_token_and_get_user_id, @@ -350,23 +355,33 @@ def test_s3_upload_file(s3_client, access_token_patcher, multipart): Test that the boto3 `upload_file` function works with the `/s3` endpoint, both for a small file uploaded in 1 part and for a large file uploaded in multiple parts. """ + bucket_name = f"gen3wf-{config['HOSTNAME']}-{TEST_USER_ID}" + object_key = f"test_s3_upload_file{'_multipart' if multipart else ''}.txt" + with patch( "gen3workflow.aws_utils.get_existing_kms_key_for_bucket", lambda _: ("test_kms_key_alias", "test_kms_key_arn"), ): - with tempfile.NamedTemporaryFile(mode="w+t", delete=True) as file_to_upload: - file_to_upload.write("Test file contents\n") + with tempfile.NamedTemporaryFile(delete=True) as file_to_upload: + file_to_upload.write(b"A" * (6 * 1024 * 1024)) # create a 6MB file + file_to_upload.flush() s3_client.upload_file( file_to_upload.name, - f"gen3wf-{config['HOSTNAME']}-{TEST_USER_ID}", - f"test_s3_upload_file{'_multipart' if multipart else ''}.txt", - # to test a multipart upload, set the chunk size to 1 to force splitting the file - # into multiple chunks: - Config=boto3.s3.transfer.TransferConfig( - multipart_threshold=1 if multipart else 9999 + bucket_name, + object_key, + # to test a multipart upload, set the part size to 1 to force splitting the file + # into multiple parts: + Config=( + boto3.s3.transfer.TransferConfig(multipart_threshold=1) + if multipart + else None ), ) + mock_aws_s3_request.assert_called_with( + f"https://{bucket_name}.s3.us-east-1.amazonaws.com/{object_key}{'?uploadId=test-upload-id' if multipart else ''}" + ) + def test_chunked_to_non_chunked_body(): """