-
Notifications
You must be signed in to change notification settings - Fork 146
Expand file tree
/
Copy pathconftest.py
More file actions
462 lines (415 loc) · 14.7 KB
/
conftest.py
File metadata and controls
462 lines (415 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
#!/usr/bin/env python3
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
import os
from logging import getLogger
from typing import Dict
import pytest
import snowflake.connector
from snowflake.snowpark import Session
from snowflake.snowpark._internal.utils import set_ast_state, AstFlagSource
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock._connection import MockServerConnection
from tests.ast.ast_test_utils import (
close_full_ast_validation_mode,
setup_full_ast_validation_mode,
)
from tests.parameters import CONNECTION_PARAMETERS
from tests.utils import (
TEST_SCHEMA,
TestFiles,
Utils,
running_on_jenkins,
running_on_public_ci,
)
# True when running inside a GitHub Actions workflow.
RUNNING_ON_GH = os.getenv("GITHUB_ACTIONS") == "true"
# True when running inside a Jenkins build (Jenkins exports JENKINS_HOME).
RUNNING_ON_JENKINS = "JENKINS_HOME" in os.environ
# Directory containing this conftest; "cassettes" holds recorded test data.
test_dir = os.path.dirname(__file__)
test_data_dir = os.path.join(test_dir, "cassettes")
_logger = getLogger(__name__)
# HTTP header fields that carry credentials or cloud-storage identifiers;
# presumably used to scrub recorded requests/responses so secrets never land
# in test artifacts — confirm against the cassette-recording code.
SNOWFLAKE_CREDENTIAL_HEADER_FIELDS = [
    "Authorization",
    "x-amz-server-side-encryption-customer-key-MD5",
    "x-amz-server-side-encryption-customer-key-md5",
    "x-amz-server-side-encryption-customer-key",
    "x-amz-server-side-encryption-customer-algorithm",
    "x-amz-id-2",
    "x-amz-request-id",
    "x-amz-version-id",
]
def print_help() -> None:
    """Print instructions for supplying connection parameters via parameters.py."""
    message = """Connection parameter must be specified in parameters.py,
for example:
CONNECTION_PARAMETERS = {
    'account': 'testaccount',
    'user': 'user1',
    'password': 'test',
    'database': 'testdb',
    'schema': 'public',
}
"""
    print(message)
def set_up_external_access_integration_resources(
    session,
    rule1,
    rule2,
    rule3,
    key1,
    key2,
    key3,
    integration1,
    integration2,
    integration3,
):
    """Provision network rules, secrets and external access integrations for
    the external-access tests, recording the created names in
    CONNECTION_PARAMETERS so tests can reference them.

    Any SnowparkSQLException during provisioning is swallowed (some
    deployments, e.g. GCP, do not support external access integrations);
    the API integration / artifact repository below is still attempted.
    """
    try:
        # IMPORTANT SETUP NOTES: the test role needs to be granted the creation privilege
        # log into the admin account and run the following sql to grant the privilege
        # GRANT CREATE INTEGRATION ON ACCOUNT TO ROLE <test_role>;
        # prepare external access resource
        #
        # FIX: Snowflake DDL syntax is `CREATE <object type> IF NOT EXISTS <name>`;
        # the previous `CREATE IF NOT EXISTS <object type>` ordering is invalid and
        # made every statement fail silently via the except clause below.
        for rule, host in (
            (rule1, "www.google.com"),
            (rule2, "www.microsoft.com"),
            (rule3, "www.amazon.com"),
        ):
            session.sql(
                f"""
                CREATE NETWORK RULE IF NOT EXISTS {rule}
                  MODE = EGRESS
                  TYPE = HOST_PORT
                  VALUE_LIST = ('{host}');
                """
            ).collect()
        for key, secret_string in (
            (key1, "replace-with-your-api-key"),
            (key2, "replace-with-your-api-key_2"),
        ):
            session.sql(
                f"""
                CREATE SECRET IF NOT EXISTS {key}
                  TYPE = GENERIC_STRING
                  SECRET_STRING = '{secret_string}';
                """
            ).collect()
        # FIX: USERNAME and PASSWORD are clauses of a single CREATE SECRET
        # statement — the previous version had a stray ';' after USERNAME
        # that cut the statement in half.
        session.sql(
            f"""
            CREATE SECRET IF NOT EXISTS {key3}
              TYPE = PASSWORD
              USERNAME = 'replace-with-your-username'
              PASSWORD = 'replace-with-your-password';
            """
        ).collect()
        for integration, rule, key in (
            (integration1, rule1, key1),
            (integration2, rule2, key2),
            (integration3, rule3, key3),
        ):
            session.sql(
                f"""
                CREATE EXTERNAL ACCESS INTEGRATION IF NOT EXISTS {integration}
                  ALLOWED_NETWORK_RULES = ({rule})
                  ALLOWED_AUTHENTICATION_SECRETS = ({key})
                  ENABLED = true;
                """
            ).collect()
        # Record the names only after provisioning succeeded, in the same
        # key order as before (rules, then keys, then integrations).
        CONNECTION_PARAMETERS["external_access_rule1"] = rule1
        CONNECTION_PARAMETERS["external_access_rule2"] = rule2
        CONNECTION_PARAMETERS["external_access_rule3"] = rule3
        CONNECTION_PARAMETERS["external_access_key1"] = key1
        CONNECTION_PARAMETERS["external_access_key2"] = key2
        CONNECTION_PARAMETERS["external_access_key3"] = key3
        CONNECTION_PARAMETERS["external_access_integration1"] = integration1
        CONNECTION_PARAMETERS["external_access_integration2"] = integration2
        CONNECTION_PARAMETERS["external_access_integration3"] = integration3
    except SnowparkSQLException:
        # GCP currently does not support external access integration
        # we can remove the exception once the integration is available on GCP
        pass
    # These two run unconditionally (outside the try) — pip artifact
    # repository support is exercised even where external access is not.
    session.sql(
        "CREATE API INTEGRATION IF NOT EXISTS "
        "SNOWPARK_PYTHON_TEST_INTEGRATION API_PROVIDER = pypi "
        "ENABLED = TRUE"
    ).collect()
    session.sql(
        "CREATE ARTIFACT REPOSITORY IF NOT EXISTS "
        f'{CONNECTION_PARAMETERS["database"]}.{CONNECTION_PARAMETERS["schema"]}.SNOWPARK_PYTHON_TEST_REPOSITORY '
        "TYPE = pip API_INTEGRATION = SNOWPARK_PYTHON_TEST_INTEGRATION"
    ).collect()
def clean_up_external_access_integration_resources():
    """Remove the external-access resource names recorded in
    CONNECTION_PARAMETERS by set_up_external_access_integration_resources.

    Missing keys are ignored, so this is safe to call even when setup
    was skipped or failed partway through.
    """
    for kind in ("rule", "key", "integration"):
        for index in (1, 2, 3):
            CONNECTION_PARAMETERS.pop(f"external_access_{kind}{index}", None)
def set_up_dataframe_processor_parameters(
session, dataframe_processor_pkg_version, dataframe_processor_location
):
def set_param_value(param, value):
if value is not None:
try:
session.sql(
f"alter session set {param} = '{value}';", _emit_ast=False
).collect(_emit_ast=False)
except Exception as ex:
_logger.error(f"Failed to set {param}, ex={ex}")
set_param_value("DATAFRAME_PROCESSOR_PKG_VERSION", dataframe_processor_pkg_version)
set_param_value("DATAFRAME_PROCESSOR_LOCATION", dataframe_processor_location)
@pytest.fixture(scope="session")
def db_parameters(local_testing_mode) -> Dict[str, str]:
    """Session-scoped connection parameters, adjusted for the environment."""
    # If its running on our public CI or Jenkins, replace the schema
    if running_on_public_ci() or running_on_jenkins():
        # tests related to external access integration requires secrets, network rule to be created ahead
        # we keep the information of the existing schema to refer to those objects
        original_schema = CONNECTION_PARAMETERS["schema"]
        CONNECTION_PARAMETERS["schema_with_secret"] = original_schema
        CONNECTION_PARAMETERS["schema"] = TEST_SCHEMA
    else:
        CONNECTION_PARAMETERS["schema_with_secret"] = CONNECTION_PARAMETERS["schema"]
    CONNECTION_PARAMETERS["local_testing"] = local_testing_mode
    CONNECTION_PARAMETERS["session_parameters"] = {
        "PYTHON_SNOWPARK_GENERATE_MULTILINE_QUERIES": True
    }
    return CONNECTION_PARAMETERS
@pytest.fixture(scope="session")
def resources_path() -> str:
    """Normalized absolute path to the shared test ``resources`` directory."""
    here = os.path.dirname(__file__)
    return os.path.normpath(os.path.join(here, "../resources"))
@pytest.fixture(scope="session")
def connection(db_parameters, local_testing_mode):
    """Yield a raw connection: a mock server in local testing mode,
    otherwise a real snowflake.connector connection built from db_parameters."""
    if local_testing_mode:
        yield MockServerConnection(options={"disable_local_testing_telemetry": True})
        return
    # Only pass through the keys the connector understands, skipping
    # anything absent or explicitly None.
    wanted = (
        "user",
        "password",
        "private_key_file",
        "private_key_file_pwd",
        "host",
        "port",
        "database",
        "schema",
        "account",
        "protocol",
        "role",
        "warehouse",
    )
    connect_kwargs = {
        key: db_parameters[key]
        for key in wanted
        if db_parameters.get(key) is not None
    }
    with snowflake.connector.connect(**connect_kwargs) as con:
        yield con
@pytest.fixture(scope="session")
def is_sample_data_available(connection, local_testing_mode) -> bool:
    """True when the SNOWFLAKE_SAMPLE_DATA database is visible to the account
    (always False in local testing mode, which has no sample data)."""
    if local_testing_mode:
        return False
    rows = (
        connection.cursor()
        .execute("show databases like 'SNOWFLAKE_SAMPLE_DATA'")
        .fetchall()
    )
    return bool(rows)
@pytest.fixture(scope="session", autouse=True)
def test_schema(connection, local_testing_mode) -> None:
    """Set up and tear down the test schema. This is automatically called per test session."""
    if local_testing_mode:
        # Nothing to provision against the mock server.
        yield
        return
    with connection.cursor() as cursor:
        cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {TEST_SCHEMA}")
        # This is needed for test_get_schema_database_works_after_use_role in test_session_suite
        cursor.execute(f"GRANT ALL PRIVILEGES ON SCHEMA {TEST_SCHEMA} TO ROLE PUBLIC")
        yield
        cursor.execute(f"DROP SCHEMA IF EXISTS {TEST_SCHEMA}")
@pytest.fixture(scope="module")
def session(
    db_parameters,
    resources_path,
    sql_simplifier_enabled,
    local_testing_mode,
    cte_optimization_enabled,
    join_alias_fix,
    ast_enabled,
    dataframe_processor_pkg_version,
    dataframe_processor_location,
    validate_ast,
    unparser_jar,
):
    """Module-scoped Snowpark Session wired up from the test option fixtures.

    Order matters here: the AST flag is set before the session is created,
    feature flags are applied to the new session, CI-only external access
    resources are provisioned, AST validation is optionally attached, and
    teardown runs the same steps in reverse before closing the session.
    """
    # Must happen before Session creation so AST emission is on from the start.
    set_ast_state(AstFlagSource.TEST, ast_enabled)
    # Names for the CI-only external access resources provisioned below.
    rule1 = "snowpark_python_test_rule1"
    rule2 = "snowpark_python_test_rule2"
    rule3 = "snowpark_python_test_rule3"
    key1 = "snowpark_python_test_key1"
    key2 = "snowpark_python_test_key2"
    key3 = "snowpark_python_test_key3"
    integration1 = "snowpark_python_test_integration1"
    integration2 = "snowpark_python_test_integration2"
    integration3 = "snowpark_python_test_integration3"
    session = (
        Session.builder.configs(db_parameters)
        .config("local_testing", local_testing_mode)
        .config(
            "session_parameters",
            {"feature_interval_types": "ENABLED"},
        )
        .create()
    )
    # Best-effort session parameters for the dataframe processor (None = skip).
    set_up_dataframe_processor_parameters(
        session, dataframe_processor_pkg_version, dataframe_processor_location
    )
    # Apply the per-run feature toggles requested via pytest options.
    session.sql_simplifier_enabled = sql_simplifier_enabled
    session._cte_optimization_enabled = cte_optimization_enabled
    session._join_alias_fix = join_alias_fix
    session.ast_enabled = ast_enabled
    if not session._generate_multiline_queries:
        session._enable_multiline_queries()
    # External access resources are only provisioned on CI with a real backend.
    if (RUNNING_ON_GH or RUNNING_ON_JENKINS) and not local_testing_mode:
        set_up_external_access_integration_resources(
            session,
            rule1,
            rule2,
            rule3,
            key1,
            key2,
            key3,
            integration1,
            integration2,
            integration3,
        )
    if validate_ast:
        # Attach the listener that round-trips every query through the unparser.
        full_ast_validation_listener = setup_full_ast_validation_mode(
            session, db_parameters, unparser_jar
        )
    # TODO: SNOW-2346239: Set parameter on user level instead of in config file
    if not local_testing_mode:
        session.sql(
            "alter session set ENABLE_EXTRACTION_PUSHDOWN_EXTERNAL_PARQUET_FOR_COPY_PHASE_I='Track';"
        ).collect()
        session.sql("alter session set ENABLE_ROW_ACCESS_POLICY=true").collect()
        # TODO: remove once the parameter is enabled by default server-side.
        session.sql(
            "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true"
        ).collect()
    try:
        yield session
    finally:
        # Tear down in reverse order of setup, then close the session.
        if validate_ast:
            close_full_ast_validation_mode(full_ast_validation_listener)
        if (RUNNING_ON_GH or RUNNING_ON_JENKINS) and not local_testing_mode:
            clean_up_external_access_integration_resources()
        session.close()
@pytest.fixture(scope="function")
def profiler_session(
    db_parameters,
    resources_path,
    sql_simplifier_enabled,
    local_testing_mode,
    cte_optimization_enabled,
):
    """Function-scoped Snowpark Session for profiler tests.

    Mirrors the ``session`` fixture but with a fresh session per test and a
    distinct set of external-access resource names (GitHub Actions only),
    so profiler tests cannot interfere with the shared module session.
    """
    # Resource names are distinct from the ones used by the ``session`` fixture.
    rule1 = "snowpark_python_profiler_test_rule1"
    rule2 = "snowpark_python_profiler_test_rule2"
    rule3 = "snowpark_python_profiler_test_rule3"
    key1 = "snowpark_python_profiler_test_key1"
    key2 = "snowpark_python_profiler_test_key2"
    key3 = "snowpark_python_profiler_test_key3"
    integration1 = "snowpark_python_profiler_test_integration1"
    integration2 = "snowpark_python_profiler_test_integration2"
    integration3 = "snowpark_python_profiler_test_integration3"
    session = (
        Session.builder.configs(db_parameters)
        .config("local_testing", local_testing_mode)
        .create()
    )
    session.sql_simplifier_enabled = sql_simplifier_enabled
    session._cte_optimization_enabled = cte_optimization_enabled
    # Note: unlike ``session``, only GitHub Actions (not Jenkins) provisions
    # external access resources here.
    if RUNNING_ON_GH and not local_testing_mode:
        set_up_external_access_integration_resources(
            session,
            rule1,
            rule2,
            rule3,
            key1,
            key2,
            key3,
            integration1,
            integration2,
            integration3,
        )
    try:
        yield session
    finally:
        if RUNNING_ON_GH and not local_testing_mode:
            clean_up_external_access_integration_resources()
        session.close()
@pytest.fixture(scope="function")
def temp_schema(connection, session, local_testing_mode) -> None:
    """Yield the name of a temporary schema for cross-schema tests,
    creating it beforehand and dropping it afterwards (real backends only)."""
    schema_name = Utils.get_fully_qualified_temp_schema(session)
    if local_testing_mode:
        yield schema_name
        return
    with connection.cursor() as cursor:
        cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
        # This is needed for test_get_schema_database_works_after_use_role in test_session_suite
        cursor.execute(f"GRANT ALL PRIVILEGES ON SCHEMA {schema_name} TO ROLE PUBLIC")
        yield schema_name
        cursor.execute(f"DROP SCHEMA IF EXISTS {schema_name}")
@pytest.fixture(scope="module")
def temp_stage(session, resources_path, local_testing_mode):
    """Module-scoped temporary stage preloaded with the parquet test file;
    yields the stage name and drops the stage on teardown (real backends only)."""
    stage_name = Utils.random_stage_name()
    files = TestFiles(resources_path)
    if not local_testing_mode:
        Utils.create_stage(session, stage_name, is_temporary=True)
        Utils.upload_to_stage(
            session, stage_name, files.test_file_parquet, compress=False
        )
    yield stage_name
    if not local_testing_mode:
        Utils.drop_stage(session, stage_name)
@pytest.fixture(scope="function", autouse=True)
def clear_session_ast_batch_on_validate_ast(session, validate_ast):
    """
    After each test, flush the AST batch so it does not pollute the next test validation.

    Autouse and function-scoped: runs its teardown after every test, but only
    flushes when full AST validation is enabled for this run.
    """
    yield
    if validate_ast:
        session._ast_batch.flush()