|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import os |
4 | | -from unittest.mock import patch |
| 4 | +import time |
| 5 | +from unittest.mock import Mock, patch |
5 | 6 |
|
6 | 7 | import pytest |
7 | 8 |
|
|
10 | 11 | from src.snowflake.connector.vendored.requests import Response |
11 | 12 |
|
12 | 13 |
|
13 | | -def build_response(status_code=200, headers=None): |
| 14 | +def build_response(content: bytes = b"", status_code: int = 200, headers=None): |
14 | 15 | response = Response() |
| 16 | + response._content = content |
15 | 17 | response.status_code = status_code |
16 | 18 | response.headers = headers |
17 | 19 | return response |
@@ -129,6 +131,66 @@ def test_timeout_handling(self, unavailable_metadata_service): |
129 | 131 | assert "has_gcp_identity_timeout" in result |
130 | 132 | assert "has_azure_managed_identity_timeout" in result |
131 | 133 |
|
| 134 | + def test_detect_platforms_executes_in_parallel(self): |
| 135 | + sleep_time = 2 |
| 136 | + |
| 137 | + def slow_requests_get(*args, **kwargs): |
| 138 | + time.sleep(sleep_time) |
| 139 | + return build_response( |
| 140 | + status_code=200, headers={"Metadata-Flavor": "Google"} |
| 141 | + ) |
| 142 | + |
| 143 | + def slow_boto3_client(*args, **kwargs): |
| 144 | + time.sleep(sleep_time) |
| 145 | + mock_client = Mock() |
| 146 | + mock_client.get_caller_identity.return_value = { |
| 147 | + "Arn": "arn:aws:iam::123456789012:user/TestUser" |
| 148 | + } |
| 149 | + return mock_client |
| 150 | + |
| 151 | + def slow_imds_get_request(*args, **kwargs): |
| 152 | + time.sleep(sleep_time) |
| 153 | + return build_response(content=b"content", status_code=200) |
| 154 | + |
| 155 | + def slow_imds_fetch_token(*args, **kwargs): |
| 156 | + return "test-token" |
| 157 | + |
| 158 | + # Mock all the network calls that run in parallel |
| 159 | + with patch( |
| 160 | + "snowflake.connector.platform_detection.requests.get", |
| 161 | + side_effect=slow_requests_get, |
| 162 | + ), patch( |
| 163 | + "snowflake.connector.platform_detection.boto3.client", |
| 164 | + side_effect=slow_boto3_client, |
| 165 | + ), patch( |
| 166 | + "snowflake.connector.platform_detection.IMDSFetcher._get_request", |
| 167 | + side_effect=slow_imds_get_request, |
| 168 | + ), patch( |
| 169 | + "snowflake.connector.platform_detection.IMDSFetcher._fetch_metadata_token", |
| 170 | + side_effect=slow_imds_fetch_token, |
| 171 | + ): |
| 172 | + start_time = time.time() |
| 173 | + result = detect_platforms(platform_detection_timeout_seconds=10) |
| 174 | + end_time = time.time() |
| 175 | + |
| 176 | + execution_time = end_time - start_time |
| 177 | + |
| 178 | + # Check that I/O calls are made in parallel. We shouldn't expect more than 2x the amount of time a single |
| 179 | + # I/O operation takes. Which in this case is 2 seconds. |
| 180 | + assert ( |
| 181 | + execution_time < 2 * sleep_time |
| 182 | + ), f"Expected parallel execution to take <4s, but took {execution_time:.2f}s" |
| 183 | + assert ( |
| 184 | + execution_time >= sleep_time |
| 185 | + ), f"Expected at least 2s due to sleep, but took {execution_time:.2f}s" |
| 186 | + |
| 187 | + assert "is_ec2_instance" in result |
| 188 | + assert "has_aws_identity" in result |
| 189 | + assert "is_azure_vm" in result |
| 190 | + assert "has_azure_managed_identity" in result |
| 191 | + assert "is_gce_vm" in result |
| 192 | + assert "has_gcp_identity" in result |
| 193 | + |
132 | 194 | @pytest.mark.parametrize( |
133 | 195 | "arn", |
134 | 196 | [ |
|
0 commit comments