Skip to content

Commit d163d32

Browse files
switched to detect platforms by adding it to the base_auth data which will get sent to login-request and eventually persisted in the SessionDPO
1 parent 44ac14c commit d163d32

File tree

3 files changed

+140
-154
lines changed

3 files changed

+140
-154
lines changed

src/snowflake/connector/auth/_auth.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
import copy
44
import json
55
import logging
6+
import os
7+
import re
68
import uuid
9+
from concurrent.futures.thread import ThreadPoolExecutor
710
from datetime import datetime, timezone
811
from threading import Thread
912
from typing import TYPE_CHECKING, Any, Callable
1013

14+
import boto3
15+
from botocore.utils import IMDSFetcher
1116
from cryptography.hazmat.backends import default_backend
1217
from cryptography.hazmat.primitives.serialization import (
1318
Encoding,
@@ -53,7 +58,9 @@
5358
)
5459
from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
5560
from ..token_cache import TokenCache, TokenKey, TokenType
61+
from ..vendored import requests
5662
from ..version import VERSION
63+
from ..wif_util import DEFAULT_ENTRA_SNOWFLAKE_RESOURCE
5764
from .no_auth import AuthNoAuth
5865

5966
if TYPE_CHECKING:
@@ -101,6 +108,136 @@ def base_auth_data(
101108
network_timeout: int | None = None,
102109
socket_timeout: int | None = None,
103110
):
111+
def detect_platforms() -> list[str]:
112+
def is_ec2_instance(timeout=0.5):
113+
try:
114+
fetcher = IMDSFetcher(timeout=timeout, num_attempts=2)
115+
document = fetcher._get_request(
116+
"/latest/dynamic/instance-identity/document",
117+
None,
118+
fetcher._fetch_metadata_token(),
119+
)
120+
return bool(document.content)
121+
except Exception:
122+
return False
123+
124+
def is_aws_lambda():
125+
return "LAMBDA_TASK_ROOT" in os.environ
126+
127+
def is_valid_arn_for_wif(arn: str) -> bool:
128+
patterns = [
129+
r"^arn:[^:]+:iam::[^:]+:user/.+$",
130+
r"^arn:[^:]+:sts::[^:]+:assumed-role/.+$",
131+
]
132+
return any(re.match(p, arn) for p in patterns)
133+
134+
def has_aws_identity():
135+
try:
136+
caller_identity = boto3.client("sts").get_caller_identity()
137+
if not caller_identity or "Arn" not in caller_identity:
138+
return False
139+
else:
140+
return is_valid_arn_for_wif(caller_identity["Arn"])
141+
except Exception:
142+
return False
143+
144+
def is_azure_vm(timeout=0.5):
145+
try:
146+
token_resp = requests.get(
147+
"http://169.254.169.254/metadata/instance?api-version=2021-02-01",
148+
headers={"Metadata": "true"},
149+
timeout=timeout,
150+
)
151+
return token_resp.status_code == 200
152+
except requests.RequestException:
153+
return False
154+
155+
def is_azure_function():
156+
service_vars = [
157+
"FUNCTIONS_WORKER_RUNTIME",
158+
"FUNCTIONS_EXTENSION_VERSION",
159+
"AzureWebJobsStorage",
160+
]
161+
return all(var in os.environ for var in service_vars)
162+
163+
def is_managed_identity_available_on_azure_vm(
164+
resource=DEFAULT_ENTRA_SNOWFLAKE_RESOURCE, timeout=0.5
165+
):
166+
endpoint = f"http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource={resource}"
167+
headers = {"Metadata": "true"}
168+
try:
169+
response = requests.get(endpoint, headers=headers, timeout=timeout)
170+
return response.status_code == 200
171+
except requests.RequestException:
172+
return False
173+
174+
def has_azure_managed_identity(on_azure_vm, on_azure_function):
175+
if on_azure_function:
176+
return bool(os.environ.get("IDENTITY_HEADER"))
177+
if on_azure_vm:
178+
return is_managed_identity_available_on_azure_vm()
179+
return False
180+
181+
def is_gce_vm(timeout=0.5):
182+
try:
183+
response = requests.get(
184+
"http://metadata.google.internal", timeout=timeout
185+
)
186+
return response.headers.get("Metadata-Flavor") == "Google"
187+
except requests.RequestException:
188+
return False
189+
190+
def is_gce_cloud_run_service():
191+
service_vars = ["K_SERVICE", "K_REVISION", "K_CONFIGURATION"]
192+
return all(var in os.environ for var in service_vars)
193+
194+
def is_gce_cloud_run_job():
195+
job_vars = ["CLOUD_RUN_JOB", "CLOUD_RUN_EXECUTION"]
196+
return all(var in os.environ for var in job_vars)
197+
198+
def has_gcp_identity(timeout=2):
199+
try:
200+
response = requests.get(
201+
"http://metadata/computeMetadata/v1/instance/service-accounts/default/email",
202+
headers={"Metadata-Flavor": "Google"},
203+
timeout=timeout,
204+
)
205+
response.raise_for_status()
206+
return bool(response.text)
207+
except requests.RequestException:
208+
return False
209+
210+
def is_github_action():
211+
return "GITHUB_ACTIONS" in os.environ
212+
213+
with ThreadPoolExecutor(max_workers=10) as executor:
214+
futures = {
215+
"is_ec2_instance": executor.submit(is_ec2_instance),
216+
"is_aws_lambda": executor.submit(is_aws_lambda),
217+
"has_aws_identity": executor.submit(has_aws_identity),
218+
"is_azure_vm": executor.submit(is_azure_vm),
219+
"is_azure_function": executor.submit(is_azure_function),
220+
"is_gce_vm": executor.submit(is_gce_vm),
221+
"is_gce_cloud_run_service": executor.submit(
222+
is_gce_cloud_run_service
223+
),
224+
"is_gce_cloud_run_job": executor.submit(is_gce_cloud_run_job),
225+
"has_gcp_identity": executor.submit(has_gcp_identity),
226+
"is_github_action": executor.submit(is_github_action),
227+
}
228+
229+
platforms = {key: future.result() for key, future in futures.items()}
230+
231+
platforms["azure_managed_identity"] = has_azure_managed_identity(
232+
platforms["is_azure_vm"], platforms["is_azure_function"]
233+
)
234+
235+
detected_platforms = [
236+
platform for platform, detected in platforms.items() if detected
237+
]
238+
239+
return detected_platforms
240+
104241
return {
105242
"data": {
106243
"CLIENT_APP_ID": internal_application_name,
@@ -120,6 +257,7 @@ def base_auth_data(
120257
"LOGIN_TIMEOUT": login_timeout,
121258
"NETWORK_TIMEOUT": network_timeout,
122259
"SOCKET_TIMEOUT": socket_timeout,
260+
"PLATFORM": detect_platforms(),
123261
},
124262
},
125263
}

src/snowflake/connector/connection.py

Lines changed: 2 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,11 @@
1818
from functools import partial
1919
from io import StringIO
2020
from logging import getLogger
21-
from threading import Lock, Thread
21+
from threading import Lock
2222
from types import TracebackType
2323
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
2424
from uuid import UUID
2525

26-
import boto3
27-
from botocore.utils import IMDSFetcher
2826
from cryptography.hazmat.backends import default_backend
2927
from cryptography.hazmat.primitives import serialization
3028
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
@@ -121,8 +119,7 @@
121119
from .time_util import HeartBeatTimer, get_time_millis
122120
from .url_util import extract_top_level_domain_from_hostname
123121
from .util_text import construct_hostname, parse_account, split_statements
124-
from .vendored import requests
125-
from .wif_util import DEFAULT_ENTRA_SNOWFLAKE_RESOURCE, AttestationProvider
122+
from .wif_util import AttestationProvider
126123

127124
DEFAULT_CLIENT_PREFETCH_THREADS = 4
128125
MAX_CLIENT_PREFETCH_THREADS = 10
@@ -277,10 +274,6 @@ def _get_private_bytes_from_file(
277274
True,
278275
bool,
279276
), # Whether to log imported packages in telemetry
280-
"log_platform_in_telemetry": (
281-
True,
282-
bool,
283-
), # Whether to log platform in telemetry
284277
"disable_query_context_cache": (
285278
False,
286279
bool,
@@ -387,133 +380,6 @@ class TypeAndBinding(NamedTuple):
387380
binding: str | None
388381

389382

390-
def detect_platforms() -> list[str]:
391-
def is_ec2_instance(timeout=0.5):
392-
try:
393-
fetcher = IMDSFetcher(timeout=timeout, num_attempts=2)
394-
document = fetcher._get_request(
395-
"/latest/dynamic/instance-identity/document",
396-
None,
397-
fetcher._fetch_metadata_token(),
398-
)
399-
return bool(document.content)
400-
except Exception:
401-
return False
402-
403-
def is_aws_lambda():
404-
return "LAMBDA_TASK_ROOT" in os.environ
405-
406-
def is_valid_arn_for_wif(arn: str) -> bool:
407-
patterns = [
408-
r"^arn:[^:]+:iam::[^:]+:user/.+$",
409-
r"^arn:[^:]+:sts::[^:]+:assumed-role/.+$",
410-
]
411-
return any(re.match(p, arn) for p in patterns)
412-
413-
def has_aws_identity():
414-
try:
415-
caller_identity = boto3.client("sts").get_caller_identity()
416-
if not caller_identity or "Arn" not in caller_identity:
417-
return False
418-
else:
419-
return is_valid_arn_for_wif(caller_identity["Arn"])
420-
except Exception:
421-
return False
422-
423-
def is_azure_vm(timeout=0.5):
424-
try:
425-
token_resp = requests.get(
426-
"http://169.254.169.254/metadata/instance?api-version=2021-02-01",
427-
headers={"Metadata": "true"},
428-
timeout=timeout,
429-
)
430-
return token_resp.status_code == 200
431-
except requests.RequestException:
432-
return False
433-
434-
def is_azure_function():
435-
service_vars = [
436-
"FUNCTIONS_WORKER_RUNTIME",
437-
"FUNCTIONS_EXTENSION_VERSION",
438-
"AzureWebJobsStorage",
439-
]
440-
return all(var in os.environ for var in service_vars)
441-
442-
def is_managed_identity_available_on_azure_vm(
443-
resource=DEFAULT_ENTRA_SNOWFLAKE_RESOURCE, timeout=0.5
444-
):
445-
endpoint = f"http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource={resource}"
446-
headers = {"Metadata": "true"}
447-
try:
448-
response = requests.get(endpoint, headers=headers, timeout=timeout)
449-
return response.status_code == 200
450-
except requests.RequestException:
451-
return False
452-
453-
def has_azure_managed_identity(on_azure_vm, on_azure_function):
454-
if on_azure_function:
455-
return bool(os.environ.get("IDENTITY_HEADER"))
456-
if on_azure_vm:
457-
return is_managed_identity_available_on_azure_vm()
458-
return False
459-
460-
def is_gce_vm(timeout=0.5):
461-
try:
462-
response = requests.get("http://metadata.google.internal", timeout=timeout)
463-
return response.headers.get("Metadata-Flavor") == "Google"
464-
except requests.RequestException:
465-
return False
466-
467-
def is_gce_cloud_run_service():
468-
service_vars = ["K_SERVICE", "K_REVISION", "K_CONFIGURATION"]
469-
return all(var in os.environ for var in service_vars)
470-
471-
def is_gce_cloud_run_job():
472-
job_vars = ["CLOUD_RUN_JOB", "CLOUD_RUN_EXECUTION"]
473-
return all(var in os.environ for var in job_vars)
474-
475-
def has_gcp_identity(timeout=2):
476-
try:
477-
response = requests.get(
478-
"http://metadata/computeMetadata/v1/instance/service-accounts/default/email",
479-
headers={"Metadata-Flavor": "Google"},
480-
timeout=timeout,
481-
)
482-
response.raise_for_status()
483-
return bool(response.text)
484-
except requests.RequestException:
485-
return False
486-
487-
def is_github_action():
488-
return "GITHUB_ACTIONS" in os.environ
489-
490-
with ThreadPoolExecutor(max_workers=10) as executor:
491-
futures = {
492-
"is_ec2_instance": executor.submit(is_ec2_instance),
493-
"is_aws_lambda": executor.submit(is_aws_lambda),
494-
"has_aws_identity": executor.submit(has_aws_identity),
495-
"is_azure_vm": executor.submit(is_azure_vm),
496-
"is_azure_function": executor.submit(is_azure_function),
497-
"is_gce_vm": executor.submit(is_gce_vm),
498-
"is_gce_cloud_run_service": executor.submit(is_gce_cloud_run_service),
499-
"is_gce_cloud_run_job": executor.submit(is_gce_cloud_run_job),
500-
"has_gcp_identity": executor.submit(has_gcp_identity),
501-
"is_github_action": executor.submit(is_github_action),
502-
}
503-
504-
platforms = {key: future.result() for key, future in futures.items()}
505-
506-
platforms["azure_managed_identity"] = has_azure_managed_identity(
507-
platforms["is_azure_vm"], platforms["is_azure_function"]
508-
)
509-
510-
detected_platforms = [
511-
platform for platform, detected in platforms.items() if detected
512-
]
513-
514-
return detected_platforms
515-
516-
517383
class SnowflakeConnection:
518384
"""Implementation of the connection object for the Snowflake Database.
519385
@@ -682,8 +548,6 @@ def __init__(
682548

683549
# get the imported modules from sys.modules
684550
self._log_telemetry_imported_packages()
685-
# log the platform of the client
686-
Thread(target=self._log_telemetry_platform_info(), daemon=True).start()
687551
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
688552
atexit.register(self._close_at_exit)
689553

@@ -2339,20 +2203,6 @@ def _log_telemetry_imported_packages(self) -> None:
23392203
)
23402204
)
23412205

2342-
def _log_telemetry_platform_info(self) -> None:
2343-
if self._log_platform_in_telemetry:
2344-
ts = get_time_millis()
2345-
self._log_telemetry(
2346-
TelemetryData.from_telemetry_data_dict(
2347-
from_dict={
2348-
TelemetryField.KEY_TYPE.value: TelemetryField.PLATFORM_INFO.value,
2349-
TelemetryField.KEY_VALUE.value: str(detect_platforms()),
2350-
},
2351-
timestamp=ts,
2352-
connection=self,
2353-
)
2354-
)
2355-
23562206
def is_valid(self) -> bool:
23572207
"""This function tries to answer the question: Is this connection still good for sending queries?
23582208
Attempts to validate the connections both on the TCP/IP and Session levels."""

src/snowflake/connector/telemetry.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class TelemetryField(Enum):
3737
PANDAS_WRITE = "client_write_pandas"
3838
# imported packages along with client
3939
IMPORTED_PACKAGES = "client_imported_packages"
40-
# platform information describing where the client is running (AWS EC2, GCP VM, Azure function, etc)
41-
PLATFORM_INFO = "client_platform_info"
4240
# multi-statement usage
4341
MULTI_STATEMENT = "client_multi_statement_query"
4442
# Keys for telemetry data sent through either in-band or out-of-band telemetry

0 commit comments

Comments
 (0)