Skip to content

Commit 30b037c

Browse files
updated client side parameter to include units. changed to only use float to be more explicit
1 parent bee50ef commit 30b037c

File tree

3 files changed

+46
-42
lines changed

3 files changed

+46
-42
lines changed

src/snowflake/connector/auth/_auth.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def base_auth_data(
102102
login_timeout: int | None = None,
103103
network_timeout: int | None = None,
104104
socket_timeout: int | None = None,
105-
platform_detection_timeout: int | float | None = None,
105+
platform_detection_timeout: float | None = None,
106106
):
107107
return {
108108
"data": {
@@ -124,7 +124,9 @@ def base_auth_data(
124124
"LOGIN_TIMEOUT": login_timeout,
125125
"NETWORK_TIMEOUT": network_timeout,
126126
"SOCKET_TIMEOUT": socket_timeout,
127-
"PLATFORM": detect_platforms(timeout=platform_detection_timeout),
127+
"PLATFORM": detect_platforms(
128+
timeout_seconds=platform_detection_timeout
129+
),
128130
},
129131
},
130132
}
@@ -180,7 +182,7 @@ def authenticate(
180182
self._rest._connection.login_timeout,
181183
self._rest._connection._network_timeout,
182184
self._rest._connection._socket_timeout,
183-
self._rest._connection._platform_detection_timeout,
185+
self._rest._connection._platform_detection_timeout_seconds,
184186
)
185187

186188
body = copy.deepcopy(body_template)

src/snowflake/connector/connection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ def _get_private_bytes_from_file(
197197
), # network timeout (infinite by default)
198198
"socket_timeout": (None, (type(None), int)),
199199
"external_browser_timeout": (120, int),
200-
"platform_detection_timeout": (
200+
"platform_detection_timeout_seconds": (
201201
None,
202-
(type(None), int, float),
202+
(type(None), float),
203203
), # Platform detection timeout for CSP metadata endpoints
204204
"backoff_policy": (DEFAULT_BACKOFF_POLICY, Callable),
205205
"passcode_in_password": (False, bool), # Snowflake MFA
@@ -698,12 +698,12 @@ def client_session_keep_alive_heartbeat_frequency(self, value) -> None:
698698
self._validate_client_session_keep_alive_heartbeat_frequency()
699699

700700
@property
701-
def platform_detection_timeout(self) -> int | float | None:
702-
return self._platform_detection_timeout
701+
def platform_detection_timeout_seconds(self) -> float | None:
702+
return self._platform_detection_timeout_seconds
703703

704-
@platform_detection_timeout.setter
705-
def platform_detection_timeout(self, value) -> None:
706-
self._platform_detection_timeout = value
704+
@platform_detection_timeout_seconds.setter
705+
def platform_detection_timeout_seconds(self, value) -> None:
706+
self._platform_detection_timeout_seconds = value
707707

708708
@property
709709
def client_prefetch_threads(self) -> int:

src/snowflake/connector/platform_detection.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class _DetectionState(Enum):
2121
TIMEOUT = "timeout"
2222

2323

24-
def is_ec2_instance(timeout):
24+
def is_ec2_instance(timeout_seconds: float):
2525
"""
2626
Check if the current environment is running on an AWS EC2 instance.
2727
@@ -31,13 +31,13 @@ def is_ec2_instance(timeout):
3131
It will ignore the token if on IMDSv1 and use the token if on IMDSv2.
3232
3333
Args:
34-
timeout: Timeout value for the metadata service request.
34+
timeout_seconds: Timeout value for the metadata service request.
3535
3636
Returns:
3737
_DetectionState: DETECTED if running on EC2, NOT_DETECTED otherwise.
3838
"""
3939
try:
40-
fetcher = IMDSFetcher(timeout=timeout, num_attempts=1)
40+
fetcher = IMDSFetcher(timeout=timeout_seconds, num_attempts=1)
4141
document = fetcher._get_request(
4242
"/latest/dynamic/instance-identity/document",
4343
None,
@@ -86,23 +86,23 @@ def is_valid_arn_for_wif(arn: str) -> bool:
8686
return any(re.match(p, arn) for p in patterns)
8787

8888

89-
def has_aws_identity(timeout):
89+
def has_aws_identity(timeout_seconds: float):
9090
"""
9191
Check if the current environment has a valid AWS identity for authentication.
9292
9393
If we retrieve an ARN from the caller identity and it is a valid WIF ARN,
9494
then we assume we have a valid AWS identity for authentication.
9595
9696
Args:
97-
timeout: Timeout value for AWS API calls.
97+
timeout_seconds: Timeout value for AWS API calls.
9898
9999
Returns:
100100
_DetectionState: DETECTED if valid AWS identity exists, NOT_DETECTED otherwise.
101101
"""
102102
try:
103103
config = Config(
104-
connect_timeout=timeout,
105-
read_timeout=timeout,
104+
connect_timeout=timeout_seconds,
105+
read_timeout=timeout_seconds,
106106
retries={"total_max_attempts": 1},
107107
)
108108
caller_identity = boto3.client("sts", config=config).get_caller_identity()
@@ -118,15 +118,15 @@ def has_aws_identity(timeout):
118118
return _DetectionState.NOT_DETECTED
119119

120120

121-
def is_azure_vm(timeout):
121+
def is_azure_vm(timeout_seconds: float):
122122
"""
123123
Check if the current environment is running on an Azure Virtual Machine.
124124
125125
If we query the Azure Instance Metadata Service and receive an HTTP 200 response,
126126
then we assume we are running on an Azure VM.
127127
128128
Args:
129-
timeout: Timeout value for the metadata service request.
129+
timeout_seconds: Timeout value for the metadata service request.
130130
131131
Returns:
132132
_DetectionState: DETECTED if on Azure VM, TIMEOUT if request times out,
@@ -136,7 +136,7 @@ def is_azure_vm(timeout):
136136
token_resp = requests.get(
137137
"http://169.254.169.254/metadata/instance?api-version=2021-02-01",
138138
headers={"Metadata": "true"},
139-
timeout=timeout,
139+
timeout=timeout_seconds,
140140
)
141141
return (
142142
_DetectionState.DETECTED
@@ -174,7 +174,7 @@ def is_azure_function():
174174

175175

176176
def is_managed_identity_available_on_azure_vm(
177-
timeout, resource=DEFAULT_ENTRA_SNOWFLAKE_RESOURCE
177+
timeout_seconds, resource=DEFAULT_ENTRA_SNOWFLAKE_RESOURCE
178178
):
179179
"""
180180
Check if Azure Managed Identity is available and accessible on an Azure VM.
@@ -184,7 +184,7 @@ def is_managed_identity_available_on_azure_vm(
184184
then we assume managed identity is available.
185185
186186
Args:
187-
timeout: Timeout value for the metadata service request.
187+
timeout_seconds: Timeout value for the metadata service request.
188188
resource: The Azure resource URI to request a token for.
189189
190190
Returns:
@@ -194,7 +194,7 @@ def is_managed_identity_available_on_azure_vm(
194194
endpoint = f"http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource={resource}"
195195
headers = {"Metadata": "true"}
196196
try:
197-
response = requests.get(endpoint, headers=headers, timeout=timeout)
197+
response = requests.get(endpoint, headers=headers, timeout=timeout_seconds)
198198
return (
199199
_DetectionState.DETECTED
200200
if response.status_code == 200
@@ -206,7 +206,7 @@ def is_managed_identity_available_on_azure_vm(
206206
return _DetectionState.NOT_DETECTED
207207

208208

209-
def has_azure_managed_identity(on_azure_vm, on_azure_function, timeout):
209+
def has_azure_managed_identity(on_azure_vm, on_azure_function, timeout_seconds: float):
210210
"""
211211
Determine if Azure Managed Identity is available in the current environment.
212212
@@ -219,7 +219,7 @@ def has_azure_managed_identity(on_azure_vm, on_azure_function, timeout):
219219
Args:
220220
on_azure_vm: Detection state for Azure VM.
221221
on_azure_function: Detection state for Azure Function.
222-
timeout: Timeout value for managed identity checks.
222+
timeout_seconds: Timeout value for managed identity checks.
223223
224224
Returns:
225225
_DetectionState: DETECTED if managed identity is available, TIMEOUT if
@@ -232,7 +232,7 @@ def has_azure_managed_identity(on_azure_vm, on_azure_function, timeout):
232232
else _DetectionState.NOT_DETECTED
233233
)
234234
if on_azure_vm == _DetectionState.DETECTED:
235-
return is_managed_identity_available_on_azure_vm(timeout)
235+
return is_managed_identity_available_on_azure_vm(timeout_seconds)
236236
if (
237237
on_azure_vm == _DetectionState.TIMEOUT
238238
or on_azure_function == _DetectionState.TIMEOUT
@@ -241,22 +241,24 @@ def has_azure_managed_identity(on_azure_vm, on_azure_function, timeout):
241241
return _DetectionState.NOT_DETECTED
242242

243243

244-
def is_gce_vm(timeout):
244+
def is_gce_vm(timeout_seconds: float):
245245
"""
246246
Check if the current environment is running on Google Compute Engine (GCE).
247247
248248
If we query the Google metadata server and receive a response with the
249249
"Metadata-Flavor: Google" header, then we assume we are running on GCE.
250250
251251
Args:
252-
timeout: Timeout value for the metadata service request.
252+
timeout_seconds: Timeout value for the metadata service request.
253253
254254
Returns:
255255
_DetectionState: DETECTED if on GCE, TIMEOUT if request times out,
256256
NOT_DETECTED otherwise.
257257
"""
258258
try:
259-
response = requests.get("http://metadata.google.internal", timeout=timeout)
259+
response = requests.get(
260+
"http://metadata.google.internal", timeout=timeout_seconds
261+
)
260262
return (
261263
_DetectionState.DETECTED
262264
if response.headers.get("Metadata-Flavor") == "Google"
@@ -306,15 +308,15 @@ def is_gce_cloud_run_job():
306308
)
307309

308310

309-
def has_gcp_identity(timeout):
311+
def has_gcp_identity(timeout_seconds: float):
310312
"""
311313
Check if the current environment has a valid Google Cloud Platform identity.
312314
313315
If we query the GCP metadata service for the default service account email
314316
and receive a non-empty response, then we assume we have a valid GCP identity.
315317
316318
Args:
317-
timeout: Timeout value for the metadata service request.
319+
timeout_seconds: Timeout value for the metadata service request.
318320
319321
Returns:
320322
_DetectionState: DETECTED if valid GCP identity exists, TIMEOUT if request
@@ -324,7 +326,7 @@ def has_gcp_identity(timeout):
324326
response = requests.get(
325327
"http://metadata/computeMetadata/v1/instance/service-accounts/default/email",
326328
headers={"Metadata-Flavor": "Google"},
327-
timeout=timeout,
329+
timeout=timeout_seconds,
328330
)
329331
response.raise_for_status()
330332
return (
@@ -353,39 +355,39 @@ def is_github_action():
353355
)
354356

355357

356-
def detect_platforms(timeout: int | float | None) -> list[str]:
358+
def detect_platforms(timeout_seconds: float | None) -> list[str]:
357359
"""
358360
Detect all potential platforms that the current environment may be running on.
359361
360362
Args:
361-
timeout: Timeout value for platform detection requests. Defaults to 0.2 seconds
363+
timeout_seconds: Timeout value for platform detection requests. Defaults to 0.2 seconds
362364
if None is provided.
363365
364366
Returns:
365367
list[str]: List of detected platform names. Platforms that timed out will have
366368
"_timeout" suffix appended to their name.
367369
"""
368-
if timeout is None:
369-
timeout = 0.2
370+
if timeout_seconds is None:
371+
timeout_seconds = 0.2
370372

371373
with ThreadPoolExecutor(max_workers=10) as executor:
372374
futures = {
373-
"is_ec2_instance": executor.submit(is_ec2_instance, timeout),
375+
"is_ec2_instance": executor.submit(is_ec2_instance, timeout_seconds),
374376
"is_aws_lambda": executor.submit(is_aws_lambda),
375-
"has_aws_identity": executor.submit(has_aws_identity, timeout),
376-
"is_azure_vm": executor.submit(is_azure_vm, timeout),
377+
"has_aws_identity": executor.submit(has_aws_identity, timeout_seconds),
378+
"is_azure_vm": executor.submit(is_azure_vm, timeout_seconds),
377379
"is_azure_function": executor.submit(is_azure_function),
378-
"is_gce_vm": executor.submit(is_gce_vm, timeout),
380+
"is_gce_vm": executor.submit(is_gce_vm, timeout_seconds),
379381
"is_gce_cloud_run_service": executor.submit(is_gce_cloud_run_service),
380382
"is_gce_cloud_run_job": executor.submit(is_gce_cloud_run_job),
381-
"has_gcp_identity": executor.submit(has_gcp_identity, timeout),
383+
"has_gcp_identity": executor.submit(has_gcp_identity, timeout_seconds),
382384
"is_github_action": executor.submit(is_github_action),
383385
}
384386

385387
platforms = {key: future.result() for key, future in futures.items()}
386388

387389
platforms["azure_managed_identity"] = has_azure_managed_identity(
388-
platforms["is_azure_vm"], platforms["is_azure_function"], timeout
390+
platforms["is_azure_vm"], platforms["is_azure_function"], timeout_seconds
389391
)
390392

391393
detected_platforms = []

0 commit comments

Comments
 (0)