@@ -21,7 +21,7 @@ class _DetectionState(Enum):
2121 TIMEOUT = "timeout"
2222
2323
24- def is_ec2_instance (timeout_seconds : float ):
24+ def is_ec2_instance (platform_detection_timeout_seconds : float ):
2525 """
2626 Check if the current environment is running on an AWS EC2 instance.
2727
@@ -31,13 +31,15 @@ def is_ec2_instance(timeout_seconds: float):
3131 It will ignore the token if on IMDSv1 and use the token if on IMDSv2.
3232
3333 Args:
34- timeout_seconds : Timeout value for the metadata service request.
34+ platform_detection_timeout_seconds : Timeout value for the metadata service request.
3535
3636 Returns:
3737 _DetectionState: DETECTED if running on EC2, NOT_DETECTED otherwise.
3838 """
3939 try :
40- fetcher = IMDSFetcher (timeout = timeout_seconds , num_attempts = 1 )
40+ fetcher = IMDSFetcher (
41+ timeout = platform_detection_timeout_seconds , num_attempts = 1
42+ )
4143 document = fetcher ._get_request (
4244 "/latest/dynamic/instance-identity/document" ,
4345 None ,
@@ -71,7 +73,7 @@ def is_aws_lambda():
7173
7274def is_valid_arn_for_wif (arn : str ) -> bool :
7375 """
74- Validate if an AWS ARN is suitable for Web Identity Federation (WIF).
76+ Validate if an AWS ARN is suitable for use with Snowflake's Workload Identity Federation (WIF).
7577
7678 Args:
7779 arn: The AWS ARN string to validate.
@@ -86,47 +88,46 @@ def is_valid_arn_for_wif(arn: str) -> bool:
8688 return any (re .match (p , arn ) for p in patterns )
8789
8890
89- def has_aws_identity (timeout_seconds : float ):
91+ def has_aws_identity (platform_detection_timeout_seconds : float ):
9092 """
9193 Check if the current environment has a valid AWS identity for authentication.
9294
9395 If we retrieve an ARN from the caller identity and it is a valid WIF ARN,
9496 then we assume we have a valid AWS identity for authentication.
9597
9698 Args:
97- timeout_seconds : Timeout value for AWS API calls.
99+ platform_detection_timeout_seconds : Timeout value for AWS API calls.
98100
99101 Returns:
100102 _DetectionState: DETECTED if valid AWS identity exists, NOT_DETECTED otherwise.
101103 """
102104 try :
103105 config = Config (
104- connect_timeout = timeout_seconds ,
105- read_timeout = timeout_seconds ,
106+ connect_timeout = platform_detection_timeout_seconds ,
107+ read_timeout = platform_detection_timeout_seconds ,
106108 retries = {"total_max_attempts" : 1 },
107109 )
108110 caller_identity = boto3 .client ("sts" , config = config ).get_caller_identity ()
109111 if not caller_identity or "Arn" not in caller_identity :
110112 return _DetectionState .NOT_DETECTED
111- else :
112- return (
113- _DetectionState .DETECTED
114- if is_valid_arn_for_wif (caller_identity ["Arn" ])
115- else _DetectionState .NOT_DETECTED
116- )
113+ return (
114+ _DetectionState .DETECTED
115+ if is_valid_arn_for_wif (caller_identity ["Arn" ])
116+ else _DetectionState .NOT_DETECTED
117+ )
117118 except Exception :
118119 return _DetectionState .NOT_DETECTED
119120
120121
121- def is_azure_vm (timeout_seconds : float ):
122+ def is_azure_vm (platform_detection_timeout_seconds : float ):
122123 """
123124 Check if the current environment is running on an Azure Virtual Machine.
124125
125126 If we query the Azure Instance Metadata Service and receive an HTTP 200 response,
126127 then we assume we are running on an Azure VM.
127128
128129 Args:
129- timeout_seconds : Timeout value for the metadata service request.
130+ platform_detection_timeout_seconds : Timeout value for the metadata service request.
130131
131132 Returns:
132133 _DetectionState: DETECTED if on Azure VM, TIMEOUT if request times out,
@@ -136,7 +137,7 @@ def is_azure_vm(timeout_seconds: float):
136137 token_resp = requests .get (
137138 "http://169.254.169.254/metadata/instance?api-version=2021-02-01" ,
138139 headers = {"Metadata" : "True" },
139- timeout = timeout_seconds ,
140+ timeout = platform_detection_timeout_seconds ,
140141 )
141142 return (
142143 _DetectionState .DETECTED
@@ -174,7 +175,7 @@ def is_azure_function():
174175
175176
176177def is_managed_identity_available_on_azure_vm (
177- timeout_seconds , resource = "https://management.azure.com/ "
178+ platform_detection_timeout_seconds , resource = "https://management.azure.com"
178179):
179180 """
180181 Check if Azure Managed Identity is available and accessible on an Azure VM.
@@ -184,7 +185,7 @@ def is_managed_identity_available_on_azure_vm(
184185 then we assume managed identity is available.
185186
186187 Args:
187- timeout_seconds : Timeout value for the metadata service request.
188+ platform_detection_timeout_seconds : Timeout value for the metadata service request.
188189 resource: The Azure resource URI to request a token for.
189190
190191 Returns:
@@ -194,7 +195,9 @@ def is_managed_identity_available_on_azure_vm(
194195 endpoint = f"http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource={ resource } "
195196 headers = {"Metadata" : "true" }
196197 try :
197- response = requests .get (endpoint , headers = headers , timeout = timeout_seconds )
198+ response = requests .get (
199+ endpoint , headers = headers , timeout = platform_detection_timeout_seconds
200+ )
198201 return (
199202 _DetectionState .DETECTED
200203 if response .status_code == 200
@@ -206,62 +209,56 @@ def is_managed_identity_available_on_azure_vm(
206209 return _DetectionState .NOT_DETECTED
207210
208211
209- def has_azure_managed_identity (timeout_seconds : float ):
212+ def is_managed_identity_available_on_azure_function ():
213+ return bool (os .environ .get ("IDENTITY_HEADER" ))
214+
215+
216+ def has_azure_managed_identity (platform_detection_timeout_seconds : float ):
210217 """
211218 Determine if Azure Managed Identity is available in the current environment.
212219
213220 If we are on Azure Functions and the IDENTITY_HEADER environment variable exists,
214221 then we assume managed identity is available.
215222 If we are on an Azure VM and can mint an access token from the managed identity endpoint,
216223 then we assume managed identity is available.
217- Assumes timeout state if either VM or Function detection timed out.
218- Otherwise, assumes it is not available
224+ Handles Azure Functions first since the checks are faster
225+ Handles Azure VM checks second since they involve network calls.
219226
220227 Args:
221- timeout_seconds : Timeout value for managed identity checks.
228+ platform_detection_timeout_seconds : Timeout value for managed identity checks.
222229
223230 Returns:
224231 _DetectionState: DETECTED if managed identity is available, TIMEOUT if
225232 detection timed out, NOT_DETECTED otherwise.
226233 """
227- has_azure_function_managed_identity = (
228- _DetectionState .DETECTED
229- if os .environ .get ("IDENTITY_HEADER" )
230- else _DetectionState .NOT_DETECTED
231- )
232- has_azure_vm_managed_identity = is_managed_identity_available_on_azure_vm (
233- timeout_seconds
234- )
235- if (
236- has_azure_vm_managed_identity == _DetectionState .DETECTED
237- or has_azure_function_managed_identity == _DetectionState .DETECTED
238- ):
239- return _DetectionState .DETECTED
240- if (
241- has_azure_vm_managed_identity == _DetectionState .TIMEOUT
242- or has_azure_function_managed_identity == _DetectionState .TIMEOUT
243- ):
244- return _DetectionState .TIMEOUT
245- return _DetectionState .NOT_DETECTED
234+ # short circuit early to save on latency and avoid minting an unnecessary token
235+ if is_azure_function () == _DetectionState .DETECTED :
236+ return (
237+ _DetectionState .DETECTED
238+ if is_managed_identity_available_on_azure_function ()
239+ else _DetectionState .NOT_DETECTED
240+ )
241+ return is_managed_identity_available_on_azure_vm (platform_detection_timeout_seconds )
246242
247243
248- def is_gce_vm (timeout_seconds : float ):
244+ def is_gce_vm (platform_detection_timeout_seconds : float ):
249245 """
250246 Check if the current environment is running on Google Compute Engine (GCE).
251247
252248 If we query the Google metadata server and receive a response with the
253249 "Metadata-Flavor: Google" header, then we assume we are running on GCE.
254250
255251 Args:
256- timeout_seconds : Timeout value for the metadata service request.
252+ platform_detection_timeout_seconds : Timeout value for the metadata service request.
257253
258254 Returns:
259255 _DetectionState: DETECTED if on GCE, TIMEOUT if request times out,
260256 NOT_DETECTED otherwise.
261257 """
262258 try :
263259 response = requests .get (
264- "http://metadata.google.internal" , timeout = timeout_seconds
260+ "http://metadata.google.internal" ,
261+ timeout = platform_detection_timeout_seconds ,
265262 )
266263 return (
267264 _DetectionState .DETECTED
@@ -274,7 +271,7 @@ def is_gce_vm(timeout_seconds: float):
274271 return _DetectionState .NOT_DETECTED
275272
276273
277- def is_gce_cloud_run_service ():
274+ def is_gcp_cloud_run_service ():
278275 """
279276 Check if the current environment is running in Google Cloud Run service.
280277
@@ -293,7 +290,7 @@ def is_gce_cloud_run_service():
293290 )
294291
295292
296- def is_gce_cloud_run_job ():
293+ def is_gcp_cloud_run_job ():
297294 """
298295 Check if the current environment is running in Google Cloud Run job.
299296
@@ -312,15 +309,15 @@ def is_gce_cloud_run_job():
312309 )
313310
314311
315- def has_gcp_identity (timeout_seconds : float ):
312+ def has_gcp_identity (platform_detection_timeout_seconds : float ):
316313 """
317314 Check if the current environment has a valid Google Cloud Platform identity.
318315
319316 If we query the GCP metadata service for the default service account email
320317 and receive a non-empty response, then we assume we have a valid GCP identity.
321318
322319 Args:
323- timeout_seconds : Timeout value for the metadata service request.
320+ platform_detection_timeout_seconds : Timeout value for the metadata service request.
324321 Returns:
325322 _DetectionState: DETECTED if valid GCP identity exists, TIMEOUT if request
326323 times out, NOT_DETECTED otherwise.
@@ -329,7 +326,7 @@ def has_gcp_identity(timeout_seconds: float):
329326 response = requests .get (
330327 "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/email" ,
331328 headers = {"Metadata-Flavor" : "Google" },
332- timeout = timeout_seconds ,
329+ timeout = platform_detection_timeout_seconds ,
333330 )
334331 return (
335332 _DetectionState .DETECTED
@@ -360,13 +357,13 @@ def is_github_action():
360357
361358
362359@cache
363- def detect_platforms (timeout_seconds : float | None ) -> list [str ]:
360+ def detect_platforms (platform_detection_timeout_seconds : float | None ) -> list [str ]:
364361 """
365362 Detect all potential platforms that the current environment may be running on.
366363 Swallows all exceptions and returns an empty list if any exception occurs to not affect main driver functionality.
367364
368365 Args:
369- timeout_seconds : Timeout value for platform detection requests. Defaults to 0.2 seconds
366+ platform_detection_timeout_seconds : Timeout value for platform detection requests. Defaults to 0.2 seconds
370367 if None is provided.
371368
372369 Returns:
@@ -375,27 +372,42 @@ def detect_platforms(timeout_seconds: float | None) -> list[str]:
375372 exception occurs during detection.
376373 """
377374 try :
378- if timeout_seconds is None :
379- timeout_seconds = 0.2
380-
381- with ThreadPoolExecutor (max_workers = 10 ) as executor :
375+ if platform_detection_timeout_seconds is None :
376+ platform_detection_timeout_seconds = 0.2
377+
378+ # Run environment-only checks synchronously (no network calls, no threading overhead)
379+ platforms = {
380+ "is_aws_lambda" : is_aws_lambda (),
381+ "is_azure_function" : is_azure_function (),
382+ "is_gce_cloud_run_service" : is_gcp_cloud_run_service (),
383+ "is_gce_cloud_run_job" : is_gcp_cloud_run_job (),
384+ "is_github_action" : is_github_action (),
385+ }
386+
387+ # Run network-calling functions in parallel
388+ with ThreadPoolExecutor (max_workers = 6 ) as executor :
382389 futures = {
383- "is_ec2_instance" : executor .submit (is_ec2_instance , timeout_seconds ),
384- "is_aws_lambda" : executor .submit (is_aws_lambda ),
385- "has_aws_identity" : executor .submit (has_aws_identity , timeout_seconds ),
386- "is_azure_vm" : executor .submit (is_azure_vm , timeout_seconds ),
387- "is_azure_function" : executor .submit (is_azure_function ),
390+ "is_ec2_instance" : executor .submit (
391+ is_ec2_instance , platform_detection_timeout_seconds
392+ ),
393+ "has_aws_identity" : executor .submit (
394+ has_aws_identity , platform_detection_timeout_seconds
395+ ),
396+ "is_azure_vm" : executor .submit (
397+ is_azure_vm , platform_detection_timeout_seconds
398+ ),
388399 "azure_managed_identity" : executor .submit (
389- has_azure_managed_identity , timeout_seconds
400+ has_azure_managed_identity , platform_detection_timeout_seconds
401+ ),
402+ "is_gce_vm" : executor .submit (
403+ is_gce_vm , platform_detection_timeout_seconds
404+ ),
405+ "has_gcp_identity" : executor .submit (
406+ has_gcp_identity , platform_detection_timeout_seconds
390407 ),
391- "is_gce_vm" : executor .submit (is_gce_vm , timeout_seconds ),
392- "is_gce_cloud_run_service" : executor .submit (is_gce_cloud_run_service ),
393- "is_gce_cloud_run_job" : executor .submit (is_gce_cloud_run_job ),
394- "has_gcp_identity" : executor .submit (has_gcp_identity , timeout_seconds ),
395- "is_github_action" : executor .submit (is_github_action ),
396408 }
397409
398- platforms = {key : future .result () for key , future in futures .items ()}
410+ platforms . update ( {key : future .result () for key , future in futures .items ()})
399411
400412 detected_platforms = []
401413 for platform_name , detection_state in platforms .items ():
0 commit comments