Skip to content

Commit bdacd5c

Browse files
added check in gce vm check to not throw an exception if headers don't exist. Don't want to fail everything there due to Attribute Error. Updated test to include a parallel timing test
1 parent f3319c5 commit bdacd5c

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

src/snowflake/connector/platform_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def is_gce_vm(platform_detection_timeout_seconds: float):
262262
)
263263
return (
264264
_DetectionState.DETECTED
265-
if response.headers.get("Metadata-Flavor") == "Google"
265+
if response.headers and response.headers.get("Metadata-Flavor") == "Google"
266266
else _DetectionState.NOT_DETECTED
267267
)
268268
except requests.Timeout:

test/unit/test_detect_platforms.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import os
4-
from unittest.mock import patch
4+
import time
5+
from unittest.mock import Mock, patch
56

67
import pytest
78

@@ -10,8 +11,9 @@
1011
from src.snowflake.connector.vendored.requests import Response
1112

1213

13-
def build_response(status_code=200, headers=None):
14+
def build_response(content: bytes = b"", status_code: int = 200, headers=None):
1415
response = Response()
16+
response._content = content
1517
response.status_code = status_code
1618
response.headers = headers
1719
return response
@@ -129,6 +131,66 @@ def test_timeout_handling(self, unavailable_metadata_service):
129131
assert "has_gcp_identity_timeout" in result
130132
assert "has_azure_managed_identity_timeout" in result
131133

134+
def test_detect_platforms_executes_in_parallel(self):
135+
sleep_time = 2
136+
137+
def slow_requests_get(*args, **kwargs):
138+
time.sleep(sleep_time)
139+
return build_response(
140+
status_code=200, headers={"Metadata-Flavor": "Google"}
141+
)
142+
143+
def slow_boto3_client(*args, **kwargs):
144+
time.sleep(sleep_time)
145+
mock_client = Mock()
146+
mock_client.get_caller_identity.return_value = {
147+
"Arn": "arn:aws:iam::123456789012:user/TestUser"
148+
}
149+
return mock_client
150+
151+
def slow_imds_get_request(*args, **kwargs):
152+
time.sleep(sleep_time)
153+
return build_response(content=b"content", status_code=200)
154+
155+
def slow_imds_fetch_token(*args, **kwargs):
156+
return "test-token"
157+
158+
# Mock all the network calls that run in parallel
159+
with patch(
160+
"snowflake.connector.platform_detection.requests.get",
161+
side_effect=slow_requests_get,
162+
), patch(
163+
"snowflake.connector.platform_detection.boto3.client",
164+
side_effect=slow_boto3_client,
165+
), patch(
166+
"snowflake.connector.platform_detection.IMDSFetcher._get_request",
167+
side_effect=slow_imds_get_request,
168+
), patch(
169+
"snowflake.connector.platform_detection.IMDSFetcher._fetch_metadata_token",
170+
side_effect=slow_imds_fetch_token,
171+
):
172+
start_time = time.time()
173+
result = detect_platforms(platform_detection_timeout_seconds=10)
174+
end_time = time.time()
175+
176+
execution_time = end_time - start_time
177+
178+
# Check that I/O calls are made in parallel. We shouldn't expect more than 2x the amount of time a single
179+
# I/O operation takes. Which in this case is 2 seconds.
180+
assert (
181+
execution_time < 2 * sleep_time
182+
), f"Expected parallel execution to take <4s, but took {execution_time:.2f}s"
183+
assert (
184+
execution_time >= sleep_time
185+
), f"Expected at least 2s due to sleep, but took {execution_time:.2f}s"
186+
187+
assert "is_ec2_instance" in result
188+
assert "has_aws_identity" in result
189+
assert "is_azure_vm" in result
190+
assert "has_azure_managed_identity" in result
191+
assert "is_gce_vm" in result
192+
assert "has_gcp_identity" in result
193+
132194
@pytest.mark.parametrize(
133195
"arn",
134196
[

0 commit comments

Comments
 (0)